torch.onnx.export and Phi-2

This example exports the Phi-2 model with torch.onnx.export. We use a dummy (untrained) model. The main difficulty is setting the dynamic shapes properly.

Model

import copy
from typing import Any, Dict
import onnx
import torch
import transformers
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import string_type, pretty_onnx


def get_phi2_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets a non initialized model with its inputs

    :param batch_size: batch size
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary

    See `Phi-2/config.json
    <https://huggingface.co/microsoft/phi-2/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "microsoft/phi-2",
        "architectures": ["PhiForCausalLM"],
        "attention_dropout": 0.0,
        "bos_token_id": 50256,
        "embd_pdrop": 0.0,
        "eos_token_id": 50256,
        "hidden_act": "gelu_new",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 10240,
        "layer_norm_eps": 1e-05,
        "max_position_embeddings": 2048,
        "model_type": "phi",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "partial_rotary_factor": 0.4,
        "qk_layernorm": False,
        "resid_pdrop": 0.1,
        "rope_scaling": None,
        "rope_theta": 10000.0,
        "tie_word_embeddings": False,
        "torch_dtype": "float16",
        "transformers_version": "4.37.0",
        "use_cache": True,
        "vocab_size": 51200,
    }
    config.update(**kwargs)
    conf = transformers.PhiConfig(**config)
    model = transformers.PhiForCausalLM(conf)
    model.eval()

    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    shapes = {}

    cache = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
    for i in range(config["num_hidden_layers"]):
        cache.update(
            torch.randn(batch_size, 32, 30, 80), torch.randn(batch_size, 32, 30, 80), i
        )
    cache2 = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
    for i in range(config["num_hidden_layers"]):
        cache2.update(
            torch.randn(batch_size + 1, 32, 31, 80),
            torch.randn(batch_size + 1, 32, 31, 80),
            i,
        )

    inputs = dict(
        input_ids=torch.randint(0, 50285, (batch_size, 3)).to(torch.int64),
        attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, 50285, (batch_size + 1, 4)).to(torch.int64),
        attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
        past_key_values=cache2,
    )
    n = len(cache.key_cache)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
    shapes.update(
        {
            "input_ids": {0: batch, 1: seq_length},
            "attention_mask": {
                0: batch,
                1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
            },
            "past_key_values": [
                [{0: batch, 2: cache_length} for _ in range(n)],  # 0: batch,
                [{0: batch, 2: cache_length} for _ in range(n)],  # 0: batch,
            ],
        }
    )

    return dict(inputs=inputs, model=model, dynamic_shapes=shapes, inputs2=inputs2)


data = get_phi2_untrained(num_hidden_layers=2)
model = data["model"]
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

print("inputs", string_type(inputs, with_shape=True))
print("dynamic_shapes", dynamic_shapes)
inputs dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
dynamic_shapes {'input_ids': {0: <class '__main__.batch'>, 1: <class '__main__.seq_length'>}, 'attention_mask': {0: <class '__main__.batch'>, 1: <_DimHint.DYNAMIC: 3>}, 'past_key_values': [[{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}], [{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}]]}

Let’s check the model is working. We need to copy the inputs before calling the model because it modifies them in place, and they would no longer be properly set up when the export starts.
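
The call that produced the output below does not appear in this snippet; a minimal sketch, assuming the model is simply run on a deep copy of the inputs:

expected = model(**copy.deepcopy(inputs))
print(expected)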

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 9.6265e-01,  3.7639e-01, -1.3292e+00,  ..., -1.0676e-03,
           4.6708e-01,  3.7059e-01],
         [ 1.5374e+00,  1.5214e+00, -9.4288e-01,  ...,  1.3344e+00,
          -4.0183e-01, -1.3003e-02],
         [ 6.1904e-02,  1.2486e+00, -1.5581e+00,  ...,  1.4394e+00,
           1.8182e-01, -9.7775e-01]],

        [[-6.1760e-01,  6.2589e-01, -1.4408e+00,  ..., -3.2613e-01,
           7.2200e-01, -9.4777e-01],
         [ 2.0299e-01,  1.7264e+00, -4.9078e-01,  ..., -1.0258e+00,
           1.3857e+00,  2.7226e-01],
         [ 1.1295e-01, -1.7574e-01, -8.3614e-01,  ..., -1.0082e+00,
           1.8203e+00,  5.9326e-02]]], grad_fn=<ViewBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)

Export

Let’s export with torch.onnx.export().

try:
    torch.onnx.export(
        copy.deepcopy(model),
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=dynamic_shapes,
        dynamo=True,
    )
except Exception as e:
    print(f"export failed due to {e}")
/home/xadupre/github/onnxscript/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/home/xadupre/github/onnxscript/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script...
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/modeling_utils.py:5055: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
  warnings.warn(
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:460: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/models/phi/modeling_phi.py:1126: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if sequence_length != 1:
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:444: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  len(self.key_cache[layer_idx]) == 0
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with internal Dynamo apis...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with internal Dynamo apis... ❌
export failed due to Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
- Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
- Debug `torch.export.export` and summit a PR to PyTorch.
- Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.

## Exception summary

<class 'torch._dynamo.exc.UserError'>: Cannot associate shape [[{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}], [{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

(Refer to the full stack trace above for more information.)

The export fails for a couple of reasons, but it is possible to patch the code to make it work. All those modifications are put in place by bypass_export_some_errors (defined in onnx_export_errors) and reverted once the export is done. Among other things, this function registers serialization functions, as shown in the example Export a model using a custom type as input.

from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors(
    patch_transformers=True, replace_dynamic_cache=True, verbose=1
) as modificator:
    print("inputs before", string_type(inputs, with_shape=True))
    inputs = modificator(inputs)
    print("inputs after", string_type(inputs, with_shape=True))
    # ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
    ep = torch.onnx.export(
        model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
    )
    ep.optimize()
    ep.save("plot_exporter_recipes_oe_phi2.onnx")
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[bypass_export_some_errors] register MambaCache
[bypass_export_some_errors] register DynamicCache
[bypass_export_some_errors] register patched_DynamicCache
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] patch transformers
[bypass_export_some_errors] replace DynamicCache
inputs before dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
inputs after dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:patched_DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 31 of general pattern rewrite rules.
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored transformer
[bypass_export_some_errors] restored DynamicCache
[bypass_export_some_errors] unregistered MambaCache
[bypass_export_some_errors] unregistered DynamicCache
[bypass_export_some_errors] unregistered patched_DynamicCache

Exported Model

Let’s display the model.

onx = onnx.load("plot_exporter_recipes_oe_phi2.onnx")
print(pretty_onnx(onx))
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib' version=1
input: name='input_ids' type=dtype('int64') shape=['s0', '3']
input: name='attention_mask' type=dtype('int64') shape=['s0', 's11 + 3']
input: name='past_key_values_key_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_key_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
init: name='model.embed_tokens.weight' type=float32 shape=(51200, 2560)
init: name='model.layers.0.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.0.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.0.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.0.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.0.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.1.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.1.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.1.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.1.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.final_layernorm.weight' type=float32 shape=(2560,)
init: name='model.final_layernorm.bias' type=float32 shape=(2560,)
init: name='lm_head.weight' type=float32 shape=(51200, 2560)
init: name='lm_head.bias' type=float32 shape=(51200,)
Constant(value_int=1) -> diagonal
Shape(input_ids, end=1, start=0) -> sym_size_int_64
Shape(input_ids, end=2, start=1) -> sym_size_int_65
  Mul(sym_size_int_64, sym_size_int_65) -> mul_171
Shape(past_key_values_key_cache_0, end=3, start=2) -> sym_size_int_66
  Add(sym_size_int_66, sym_size_int_65) -> add_4
  Concat(sym_size_int_65, add_4, axis=0) -> val_4
Gather(model.embed_tokens.weight, input_ids, axis=0) -> embedding
  LayerNormalization(embedding, model.layers.0.input_layernorm.weight, model.layers.0.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm
Constant(value=1.0) -> val_0
Constant(value=1) -> val_1
  Range(sym_size_int_66, add_4, val_1) -> arange
Constant(value=0) -> dim_0
  Unsqueeze(arange, dim_0) -> unsqueeze
Constant(value=-3.4028234...) -> val_3
  Expand(val_3, val_4) -> full
  Trilu(full, diagonal, upper=1) -> triu
Constant(value=0) -> val_7
Constant(value=1) -> val_8
  Range(val_7, add_4, val_8) -> arange_1
Constant(value=[-1, 1]) -> val_10
  Reshape(arange, val_10, allowzero=0) -> view
    Greater(arange_1, view) -> gt
      Cast(gt, to=1) -> convert_element_type_default
    Mul(triu, convert_element_type_default) -> mul_16
Constant(value=0) -> dim_0_2
  Unsqueeze(mul_16, dim_0_2) -> unsqueeze_3
Constant(value=1) -> dim_0_3
  Unsqueeze(unsqueeze_3, dim_0_3) -> unsqueeze_4
Constant(value=0) -> val_11
Constant(value=2) -> val_19
Constant(value=[1]) -> val_35
Constant(value=[-1]) -> val_36
  Concat(sym_size_int_64, val_35, val_36, val_36, axis=0) -> val_37
    Abs(val_37) -> size_1
    Expand(unsqueeze_4, size_1) -> expand_1
      Shape(expand_1, start=0) -> val_171
  Gather(val_171, val_19, axis=0) -> val_172
Constant(value=1) -> val_54
  Range(val_11, val_172, val_54) -> val_173
Constant(value=1) -> dim_0_4
  Unsqueeze(attention_mask, dim_0_4) -> unsqueeze_5
Constant(value=2) -> dim_0_5
  Unsqueeze(unsqueeze_5, dim_0_5) -> unsqueeze_6
    Cast(unsqueeze_6, to=1) -> convert_element_type_default_1
      Add(expand_1, convert_element_type_default_1) -> add_84
Constant(value=0.0) -> scalar_tensor_default
  Equal(add_84, scalar_tensor_default) -> eq_57
Constant(value=-3.4028234...) -> val_119
  Where(eq_57, val_119, expand_1) -> masked_fill
    Transpose(masked_fill, perm=[2,1,0,3]) -> val_180
Constant(value=-1) -> val_178
  Unsqueeze(val_173, val_178) -> val_179
Transpose(expand_1, perm=[2,1,0,3]) -> val_181
  ScatterND(val_181, val_179, val_180, reduction=b'none') -> val_182
    Transpose(val_182, perm=[2,1,0,3]) -> slice_scatter
      Transpose(slice_scatter, perm=[1,0,2,3]) -> val_192
Shape(expand_1, start=0) -> val_184
  Gather(val_184, val_54, axis=0) -> val_185
  Range(val_11, val_185, val_54) -> val_186
  Unsqueeze(val_186, val_178) -> val_191
Transpose(expand_1, perm=[1,0,2,3]) -> val_193
  ScatterND(val_193, val_191, val_192, reduction=b'none') -> val_194
    Transpose(val_194, perm=[1,0,2,3]) -> slice_scatter_1
Shape(expand_1, start=0) -> val_196
  Gather(val_196, val_11, axis=0) -> val_197
  Range(val_11, val_197, val_54) -> val_198
  Unsqueeze(val_198, val_178) -> val_203
    ScatterND(expand_1, val_203, slice_scatter_1, reduction=b'none') -> slice_scatter_2
Constant(value=1) -> dim_0_8
  Unsqueeze(unsqueeze, dim_0_8) -> unsqueeze_9
    Cast(unsqueeze_9, to=1) -> _to_copy_1
Constant(value=[[[1.0], [...) -> _to_copy_2
  MatMul(_to_copy_2, _to_copy_1) -> matmul
    Transpose(matmul, perm=[0,2,1]) -> transpose
      Concat(transpose, transpose, axis=-1) -> cat
        Cos(cat) -> cos
        Sin(cat) -> sin
Constant(value=[2560]) -> val_235
  Concat(mul_171, val_235, axis=0) -> val_236
    Reshape(layer_norm, val_236, allowzero=0) -> view_1
Transpose(model.layers.0.self_attn.q_proj.weight, perm=[1,0]) -> t
  Gemm(view_1, t, model.layers.0.self_attn.q_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_238
  Reshape(addmm, val_238, allowzero=0) -> view_2
Concat(mul_171, val_235, axis=0) -> val_240
  Reshape(layer_norm, val_240, allowzero=0) -> view_3
Transpose(model.layers.0.self_attn.k_proj.weight, perm=[1,0]) -> t_1
  Gemm(view_3, t_1, model.layers.0.self_attn.k_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_1
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_242
  Reshape(addmm_1, val_242, allowzero=0) -> view_4
Concat(mul_171, val_235, axis=0) -> val_244
  Reshape(layer_norm, val_244, allowzero=0) -> view_5
Transpose(model.layers.0.self_attn.v_proj.weight, perm=[1,0]) -> t_2
  Gemm(view_5, t_2, model.layers.0.self_attn.v_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_2
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_246
  Reshape(addmm_2, val_246, allowzero=0) -> view_6
Constant(value=[32]) -> val_248
Constant(value=[80]) -> val_249
  Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_250
    Reshape(view_2, val_250, allowzero=0) -> view_7
      Transpose(view_7, perm=[0,2,1,3]) -> transpose_1
  Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_252
    Reshape(view_4, val_252, allowzero=0) -> view_8
      Transpose(view_8, perm=[0,2,1,3]) -> transpose_2
  Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_254
    Reshape(view_6, val_254, allowzero=0) -> view_9
      Transpose(view_9, perm=[0,2,1,3]) -> transpose_3
        Concat(past_key_values_value_cache_0, transpose_3, axis=-2) -> cat_6
Constant(value=[0]) -> val_258
Constant(value=[32]) -> val_262
Constant(value=[3]) -> val_265
Constant(value_ints=[1]) -> val_266
  Slice(transpose_1, val_258, val_262, val_265, val_266) -> slice_24
Constant(value=[32]) -> val_269
Constant(value=[922337203...) -> val_272
Constant(value=[3]) -> val_275
Constant(value_ints=[1]) -> val_276
  Slice(transpose_1, val_269, val_272, val_275, val_276) -> slice_25
Constant(value=[0]) -> val_279
Constant(value=[32]) -> val_282
Constant(value=[3]) -> val_285
Constant(value_ints=[1]) -> val_286
  Slice(transpose_2, val_279, val_282, val_285, val_286) -> slice_26
Constant(value=[32]) -> val_289
Constant(value=[922337203...) -> val_292
Constant(value=[3]) -> val_295
Constant(value_ints=[1]) -> val_296
  Slice(transpose_2, val_289, val_292, val_295, val_296) -> slice_27
Constant(value=1) -> dim_0_9
  Unsqueeze(cos, dim_0_9) -> unsqueeze_10
    Mul(slice_24, unsqueeze_10) -> mul_233
Constant(value=1) -> dim_0_10
  Unsqueeze(sin, dim_0_10) -> unsqueeze_11
Constant(value=[0]) -> val_299
Constant(value=[16]) -> val_303
Constant(value=[3]) -> val_306
Constant(value_ints=[1]) -> val_307
  Slice(slice_24, val_299, val_303, val_306, val_307) -> slice_28
Constant(value=[16]) -> val_310
Constant(value=[922337203...) -> val_313
Constant(value=[3]) -> val_316
Constant(value_ints=[1]) -> val_317
  Slice(slice_24, val_310, val_313, val_316, val_317) -> slice_29
    Neg(slice_29) -> neg
    Concat(neg, slice_28, axis=-1) -> cat_1
    Mul(cat_1, unsqueeze_11) -> mul_250
      Add(mul_233, mul_250) -> add_306
    Concat(add_306, slice_25, axis=-1) -> cat_3
      Shape(cat_3, start=0) -> val_368
    Mul(slice_26, unsqueeze_10) -> mul_258
Constant(value=[0]) -> val_320
Constant(value=[16]) -> val_323
Constant(value=[3]) -> val_326
Constant(value_ints=[1]) -> val_327
  Slice(slice_26, val_320, val_323, val_326, val_327) -> slice_30
Constant(value=[16]) -> val_330
Constant(value=[922337203...) -> val_333
Constant(value=[3]) -> val_336
Constant(value_ints=[1]) -> val_337
  Slice(slice_26, val_330, val_333, val_336, val_337) -> slice_31
    Neg(slice_31) -> neg_1
    Concat(neg_1, slice_30, axis=-1) -> cat_2
    Mul(cat_2, unsqueeze_11) -> mul_275
      Add(mul_258, mul_275) -> add_342
    Concat(add_342, slice_27, axis=-1) -> cat_4
      Concat(past_key_values_key_cache_0, cat_4, axis=-2) -> cat_5
        Shape(cat_5, start=0) -> val_377
Constant(value_ints=[-1]) -> val_369
  Gather(val_368, val_369, axis=0) -> val_370
    Cast(val_370, to=1) -> val_371
      Sqrt(val_371) -> val_374
Constant(value=1.0) -> val_373
  Div(val_373, val_374) -> val_375
    Sqrt(val_375) -> val_390
      Mul(cat_3, val_390) -> val_391
Constant(value_ints=[9223372036854775807]) -> val_378
  Slice(val_377, val_36, val_378) -> val_379
Constant(value=[-2]) -> val_380
  Slice(val_377, val_380, val_36) -> val_381
Constant(value_ints=[-9223372036854775808]) -> val_382
  Slice(val_377, val_382, val_380) -> val_383
    Concat(val_383, val_379, val_381, axis=0) -> val_388
Constant(value_ints=[-1]) -> val_384
  Concat(val_384, val_381, val_379, axis=0) -> val_385
    Reshape(cat_5, val_385, allowzero=0) -> val_386
      Transpose(val_386, perm=[0,2,1]) -> val_387
      Reshape(val_387, val_388, allowzero=0) -> val_389
    Sqrt(val_375) -> val_392
      Mul(val_389, val_392) -> val_393
        MatMul(val_391, val_393) -> val_394
      Add(val_394, slice_scatter_2) -> val_395
        Softmax(val_395, axis=-1) -> val_396
          MatMul(val_396, cat_6) -> scaled_dot_product_attention
            Transpose(scaled_dot_product_attention, perm=[0,2,1,3]) -> transpose_4
  Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_399
    Reshape(transpose_4, val_399, allowzero=0) -> view_10
  Concat(mul_171, val_235, axis=0) -> val_401
    Reshape(view_10, val_401, allowzero=0) -> view_11
Transpose(model.layers.0.self_attn.dense.weight, perm=[1,0]) -> t_3
  Gemm(view_11, t_3, model.layers.0.self_attn.dense.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_3
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_403
  Reshape(addmm_3, val_403, allowzero=0) -> view_12
Concat(mul_171, val_235, axis=0) -> val_405
  Reshape(layer_norm, val_405, allowzero=0) -> view_13
Transpose(model.layers.0.mlp.fc1.weight, perm=[1,0]) -> t_4
  Gemm(view_13, t_4, model.layers.0.mlp.fc1.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_4
Constant(value=[10240]) -> val_407
  Concat(sym_size_int_64, sym_size_int_65, val_407, axis=0) -> val_408
    Reshape(addmm_4, val_408, allowzero=0) -> view_14
Constant(value=0.5) -> val_410
  Mul(view_14, val_410) -> mul_356
Constant(value=3.0) -> val_411
  Pow(view_14, val_411) -> pow_1
Constant(value=0.04471499...) -> val_412
  Mul(pow_1, val_412) -> mul_363
    Add(view_14, mul_363) -> add_438
Constant(value=0.79788458...) -> val_413
  Mul(add_438, val_413) -> mul_370
    Tanh(mul_370) -> tanh
  Add(tanh, val_0) -> add_451
    Mul(mul_356, add_451) -> mul_380
  Concat(mul_171, val_407, axis=0) -> val_414
    Reshape(mul_380, val_414, allowzero=0) -> view_15
Transpose(model.layers.0.mlp.fc2.weight, perm=[1,0]) -> t_5
  Gemm(view_15, t_5, model.layers.0.mlp.fc2.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_5
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_416
  Reshape(addmm_5, val_416, allowzero=0) -> view_16
    Add(view_12, view_16) -> add_474
  Add(add_474, embedding) -> add_479
    LayerNormalization(add_479, model.layers.1.input_layernorm.weight, model.layers.1.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_1
  Concat(mul_171, val_235, axis=0) -> val_418
    Reshape(layer_norm_1, val_418, allowzero=0) -> view_17
Transpose(model.layers.1.self_attn.q_proj.weight, perm=[1,0]) -> t_6
  Gemm(view_17, t_6, model.layers.1.self_attn.q_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_6
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_420
  Reshape(addmm_6, val_420, allowzero=0) -> view_18
Concat(mul_171, val_235, axis=0) -> val_422
  Reshape(layer_norm_1, val_422, allowzero=0) -> view_19
Transpose(model.layers.1.self_attn.k_proj.weight, perm=[1,0]) -> t_7
  Gemm(view_19, t_7, model.layers.1.self_attn.k_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_7
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_424
  Reshape(addmm_7, val_424, allowzero=0) -> view_20
Concat(mul_171, val_235, axis=0) -> val_426
  Reshape(layer_norm_1, val_426, allowzero=0) -> view_21
Transpose(model.layers.1.self_attn.v_proj.weight, perm=[1,0]) -> t_8
  Gemm(view_21, t_8, model.layers.1.self_attn.v_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_8
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_428
  Reshape(addmm_8, val_428, allowzero=0) -> view_22
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_430
  Reshape(view_18, val_430, allowzero=0) -> view_23
    Transpose(view_23, perm=[0,2,1,3]) -> transpose_5
  Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_432
    Reshape(view_20, val_432, allowzero=0) -> view_24
      Transpose(view_24, perm=[0,2,1,3]) -> transpose_6
  Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_434
    Reshape(view_22, val_434, allowzero=0) -> view_25
      Transpose(view_25, perm=[0,2,1,3]) -> transpose_7
        Concat(past_key_values_value_cache_1, transpose_7, axis=-2) -> cat_12
Constant(value=[0]) -> val_438
Constant(value=[32]) -> val_441
Constant(value=[3]) -> val_444
Constant(value_ints=[1]) -> val_445
  Slice(transpose_5, val_438, val_441, val_444, val_445) -> slice_38
Constant(value=[32]) -> val_448
Constant(value=[922337203...) -> val_451
Constant(value=[3]) -> val_454
Constant(value_ints=[1]) -> val_455
  Slice(transpose_5, val_448, val_451, val_454, val_455) -> slice_39
Constant(value=[0]) -> val_458
Constant(value=[32]) -> val_461
Constant(value=[3]) -> val_464
Constant(value_ints=[1]) -> val_465
  Slice(transpose_6, val_458, val_461, val_464, val_465) -> slice_40
Constant(value=[32]) -> val_468
Constant(value=[922337203...) -> val_471
Constant(value=[3]) -> val_474
Constant(value_ints=[1]) -> val_475
  Slice(transpose_6, val_468, val_471, val_474, val_475) -> slice_41
Constant(value=1) -> dim_0_11
  Unsqueeze(cos, dim_0_11) -> unsqueeze_12
    Mul(slice_38, unsqueeze_12) -> mul_474
Constant(value=1) -> dim_0_12
  Unsqueeze(sin, dim_0_12) -> unsqueeze_13
Constant(value=[0]) -> val_478
Constant(value=[16]) -> val_481
Constant(value=[3]) -> val_484
Constant(value_ints=[1]) -> val_485
  Slice(slice_38, val_478, val_481, val_484, val_485) -> slice_42
Constant(value=[16]) -> val_488
Constant(value=[922337203...) -> val_491
Constant(value=[3]) -> val_494
Constant(value_ints=[1]) -> val_495
  Slice(slice_38, val_488, val_491, val_494, val_495) -> slice_43
    Neg(slice_43) -> neg_2
    Concat(neg_2, slice_42, axis=-1) -> cat_7
    Mul(cat_7, unsqueeze_13) -> mul_491
      Add(mul_474, mul_491) -> add_604
    Concat(add_604, slice_39, axis=-1) -> cat_9
      Shape(cat_9, start=0) -> val_546
    Mul(slice_40, unsqueeze_12) -> mul_499
Constant(value=[0]) -> val_498
Constant(value=[16]) -> val_501
Constant(value=[3]) -> val_504
Constant(value_ints=[1]) -> val_505
  Slice(slice_40, val_498, val_501, val_504, val_505) -> slice_44
Constant(value=[16]) -> val_508
Constant(value=[922337203...) -> val_511
Constant(value=[3]) -> val_514
Constant(value_ints=[1]) -> val_515
  Slice(slice_40, val_508, val_511, val_514, val_515) -> slice_45
    Neg(slice_45) -> neg_3
    Concat(neg_3, slice_44, axis=-1) -> cat_8
    Mul(cat_8, unsqueeze_13) -> mul_516
      Add(mul_499, mul_516) -> add_640
    Concat(add_640, slice_41, axis=-1) -> cat_10
      Concat(past_key_values_key_cache_1, cat_10, axis=-2) -> cat_11
        Shape(cat_11, start=0) -> val_555
  Slice(val_555, val_380, val_36) -> val_558
Constant(value_ints=[-1]) -> val_547
  Gather(val_546, val_547, axis=0) -> val_548
    Cast(val_548, to=1) -> val_549
      Sqrt(val_549) -> val_552
Constant(value=1.0) -> val_551
  Div(val_551, val_552) -> val_553
    Sqrt(val_553) -> val_567
      Mul(cat_9, val_567) -> val_568
Constant(value_ints=[9223372036854775807]) -> val_556
  Slice(val_555, val_36, val_556) -> val_557
Constant(value_ints=[-9223372036854775808]) -> val_559
  Slice(val_555, val_559, val_380) -> val_560
    Concat(val_560, val_557, val_558, axis=0) -> val_565
Constant(value_ints=[-1]) -> val_561
  Concat(val_561, val_558, val_557, axis=0) -> val_562
    Reshape(cat_11, val_562, allowzero=0) -> val_563
      Transpose(val_563, perm=[0,2,1]) -> val_564
      Reshape(val_564, val_565, allowzero=0) -> val_566
    Sqrt(val_553) -> val_569
      Mul(val_566, val_569) -> val_570
        MatMul(val_568, val_570) -> val_571
      Add(val_571, slice_scatter_2) -> val_572
        Softmax(val_572, axis=-1) -> val_573
          MatMul(val_573, cat_12) -> scaled_dot_product_attention_1
            Transpose(scaled_dot_product_attention_1, perm=[0,2,1,3]) -> transpose_8
  Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_576
    Reshape(transpose_8, val_576, allowzero=0) -> view_26
  Concat(mul_171, val_235, axis=0) -> val_578
    Reshape(view_26, val_578, allowzero=0) -> view_27
Transpose(model.layers.1.self_attn.dense.weight, perm=[1,0]) -> t_9
  Gemm(view_27, t_9, model.layers.1.self_attn.dense.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_9
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_580
  Reshape(addmm_9, val_580, allowzero=0) -> view_28
Concat(mul_171, val_235, axis=0) -> val_582
  Reshape(layer_norm_1, val_582, allowzero=0) -> view_29
Transpose(model.layers.1.mlp.fc1.weight, perm=[1,0]) -> t_10
  Gemm(view_29, t_10, model.layers.1.mlp.fc1.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_10
Concat(sym_size_int_64, sym_size_int_65, val_407, axis=0) -> val_584
  Reshape(addmm_10, val_584, allowzero=0) -> view_30
  Mul(view_30, val_410) -> mul_597
Pow(view_30, val_411) -> pow_2
  Mul(pow_2, val_412) -> mul_604
    Add(view_30, mul_604) -> add_736
  Mul(add_736, val_413) -> mul_611
    Tanh(mul_611) -> tanh_1
  Add(tanh_1, val_0) -> add_749
    Mul(mul_597, add_749) -> mul_621
  Concat(mul_171, val_407, axis=0) -> val_586
    Reshape(mul_621, val_586, allowzero=0) -> view_31
Transpose(model.layers.1.mlp.fc2.weight, perm=[1,0]) -> t_11
  Gemm(view_31, t_11, model.layers.1.mlp.fc2.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_11
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_588
  Reshape(addmm_11, val_588, allowzero=0) -> view_32
    Add(view_28, view_32) -> add_772
    Add(add_772, add_479) -> add_777
      LayerNormalization(add_777, model.final_layernorm.weight, model.final_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_2
  Concat(mul_171, val_235, axis=0) -> val_620
    Reshape(layer_norm_2, val_620, allowzero=0) -> view_33
Transpose(lm_head.weight, perm=[1,0]) -> t_12
  Gemm(view_33, t_12, lm_head.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_12
Constant(value=[51200]) -> val_622
  Concat(sym_size_int_64, sym_size_int_65, val_622, axis=0) -> val_623
    Reshape(addmm_12, val_623, allowzero=0) -> view_34
output: name='view_34' type=dtype('float32') shape=['', '', '']
output: name='cat_5' type=dtype('float32') shape=['s0', 32, '', 80]
output: name='cat_6' type=dtype('float32') shape=['s0', 32, '', 80]
output: name='cat_11' type=dtype('float32') shape=['s0', 32, '', 80]
output: name='cat_12' type=dtype('float32') shape=['s0', 32, '', 80]

Visually.
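
The rendering call does not appear in the snippet; a minimal sketch, assuming the figure is produced with the plot_dot helper imported at the top of the script:

plot_dot(onx)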

(figure: plot exporter recipes oe phi2, a rendering of the exported ONNX graph)

Total running time of the script: (0 minutes 16.675 seconds)

Related examples

to_onnx and Phi-2

to_onnx and submodules from LLMs

torch.onnx.export and a model with a test

Gallery generated by Sphinx-Gallery