torch.onnx.export and Phi-2

This example exports the Phi-2 model with torch.onnx.export. We use a dummy (untrained) model, so the weights do not matter. The main difficulty is setting the dynamic shapes properly.

Model

import copy
from typing import Any, Dict
import onnx
import torch
import transformers
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import string_type, pretty_onnx


def get_phi2_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets a non initialized model with its inputs

    :param batch_size: batch size
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary

    See `Phi-2/config.json
    <https://huggingface.co/microsoft/phi-2/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "microsoft/phi-2",
        "architectures": ["PhiForCausalLM"],
        "attention_dropout": 0.0,
        "bos_token_id": 50256,
        "embd_pdrop": 0.0,
        "eos_token_id": 50256,
        "hidden_act": "gelu_new",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 10240,
        "layer_norm_eps": 1e-05,
        "max_position_embeddings": 2048,
        "model_type": "phi",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "partial_rotary_factor": 0.4,
        "qk_layernorm": False,
        "resid_pdrop": 0.1,
        "rope_scaling": None,
        "rope_theta": 10000.0,
        "tie_word_embeddings": False,
        "torch_dtype": "float16",
        "transformers_version": "4.37.0",
        "use_cache": True,
        "vocab_size": 51200,
    }
    config.update(**kwargs)
    conf = transformers.PhiConfig(**config)
    model = transformers.PhiForCausalLM(conf)
    model.eval()

    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    shapes = {}

    cache = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
    for i in range(config["num_hidden_layers"]):
        cache.update(
            torch.randn(batch_size, 32, 30, 80), torch.randn(batch_size, 32, 30, 80), i
        )
    cache2 = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
    for i in range(config["num_hidden_layers"]):
        cache2.update(
            torch.randn(batch_size + 1, 32, 31, 80),
            torch.randn(batch_size + 1, 32, 31, 80),
            i,
        )

    inputs = dict(
        input_ids=torch.randint(0, 50285, (batch_size, 3)).to(torch.int64),
        attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, 50285, (batch_size + 1, 4)).to(torch.int64),
        attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
        past_key_values=cache2,
    )
    n = len(cache.key_cache)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)
    shapes.update(
        {
            "input_ids": {0: batch, 1: seq_length},
            "attention_mask": {
                0: batch,
                1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
            },
            "past_key_values": [
                [{0: batch, 2: cache_length} for _ in range(n)],  # 0: batch,
                [{0: batch, 2: cache_length} for _ in range(n)],  # 0: batch,
            ],
        }
    )

    return dict(inputs=inputs, model=model, dynamic_shapes=shapes, inputs2=inputs2)


data = get_phi2_untrained(num_hidden_layers=2)
model = data["model"]
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

print("inputs", string_type(inputs, with_shape=True))
print("dynamic_shapes", dynamic_shapes)
inputs dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
dynamic_shapes {'input_ids': {0: <class '__main__.batch'>, 1: <class '__main__.seq_length'>}, 'attention_mask': {0: <class '__main__.batch'>, 1: <_DimHint.DYNAMIC: 3>}, 'past_key_values': [[{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}], [{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}]]}
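The nested list under past_key_values follows the structure printed above for DynamicCache: first one shape dictionary per layer for key_cache, then one per layer for value_cache, with dictionary keys being axis indices. A minimal sketch for a 2-layer model (reusing Dim objects as in the function above, names here are illustrative):

# Sketch only: how the past_key_values entry of dynamic_shapes is built.
batch = torch.export.Dim("batch", min=1, max=1024)
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
num_layers = 2
past_kv_shapes = [
    [{0: batch, 2: cache_length} for _ in range(num_layers)],  # key_cache
    [{0: batch, 2: cache_length} for _ in range(num_layers)],  # value_cache
]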

Let’s check the model is working. We need to copy the inputs before calling the model because it modifies them in place (the cache is updated), so they would no longer be properly set up when the export starts.
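The call itself is not reproduced above; a minimal sketch of it, running the model on a deep copy of the inputs and printing the result shown below, would be:

# Sketch: call the model on a deep copy so the original inputs (and their
# DynamicCache) stay untouched for the export.
expected = model(**copy.deepcopy(inputs))
print(expected)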

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 0.7979,  0.3673,  1.7276,  ...,  1.0755, -0.5409,  1.7213],
         [ 0.0160, -1.5066,  2.6774,  ..., -0.3734, -1.6701,  0.8857],
         [-0.1865, -0.7282,  1.2860,  ...,  1.8946,  0.0776,  0.3918]],

        [[-0.5348,  0.1761,  1.3224,  ..., -0.0308,  0.6626, -0.8239],
         [ 1.2772, -0.2547,  0.2357,  ...,  0.4670, -0.8453,  0.2779],
         [-0.7813, -1.1209,  1.7710,  ..., -0.2175, -1.8416,  0.2227]]],
       grad_fn=<ViewBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)

Export

Let’s export with torch.onnx.export().

try:
    torch.onnx.export(
        copy.deepcopy(model),
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=dynamic_shapes,
        dynamo=True,
    )
except Exception as e:
    print(f"export failed due to {e}")
/home/xadupre/github/onnxscript/onnxscript/converter.py:823: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/home/xadupre/github/onnxscript/onnxscript/converter.py:823: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script...
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:460: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/models/phi/modeling_phi.py:703: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if sequence_length != 1:
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:444: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  len(self.key_cache[layer_idx]) == 0




def forward(self, arg0_1: "f32[51200]", arg1_1: "f32[51200, 2560]", arg2_1: "f32[51200, 2560]", arg3_1: "f32[2560]", arg4_1: "f32[2560]", arg5_1: "f32[2560]", arg6_1: "f32[2560]", arg7_1: "f32[2560]", arg8_1: "f32[2560, 2560]", arg9_1: "f32[2560]", arg10_1: "f32[2560, 2560]", arg11_1: "f32[2560]", arg12_1: "f32[2560, 2560]", arg13_1: "f32[2560]", arg14_1: "f32[2560, 2560]", arg15_1: "f32[10240]", arg16_1: "f32[10240, 2560]", arg17_1: "f32[2560]", arg18_1: "f32[2560, 10240]", arg19_1: "f32[2560]", arg20_1: "f32[2560]", arg21_1: "f32[2560]", arg22_1: "f32[2560, 2560]", arg23_1: "f32[2560]", arg24_1: "f32[2560, 2560]", arg25_1: "f32[2560]", arg26_1: "f32[2560, 2560]", arg27_1: "f32[2560]", arg28_1: "f32[2560, 2560]", arg29_1: "f32[10240]", arg30_1: "f32[10240, 2560]", arg31_1: "f32[2560]", arg32_1: "f32[2560, 10240]", arg33_1: "f64[]", arg34_1: "f64[]", arg35_1: "f64[]", arg36_1: "i64[]", arg37_1: "f64[]", arg38_1: "f32[16]", arg39_1: "i64[s0, s1]", arg40_1: "i64[s2, s3]", arg41_1: "f32[s4, s5, s6, s7]", arg42_1: "f32[s8, s9, s10, s11]", arg43_1: "f32[s12, s13, s14, s15]", arg44_1: "f32[s16, s17, s18, s19]"):
    # No stacktrace found for following nodes
    embedding: "f32[s0, s1, 2560]" = torch.ops.aten.embedding.default(arg2_1, arg39_1);  arg2_1 = None
    sym_size: "Sym(s6)" = torch.ops.aten.sym_size.int(arg41_1, 2);  sym_size = None
    sym_size_int: "Sym(s6)" = torch.ops.aten.sym_size.int(arg41_1, 2);  arg41_1 = None
    scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.int64)
    sym_size_1: "Sym(s1)" = torch.ops.aten.sym_size.int(embedding, 1);  embedding = sym_size_1 = None
    sym_size_int_1: "Sym(s1)" = torch.ops.aten.sym_size.int(arg39_1, 1);  arg39_1 = None
    scalar_tensor_1: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int_1, dtype = torch.int64);  sym_size_int_1 = None
    add: "i64[]" = torch.ops.aten.add.Tensor(scalar_tensor, scalar_tensor_1);  scalar_tensor = scalar_tensor_1 = None
    item: "Sym(u0)" = torch.ops.aten.item.default(add);  add = None
    arange = torch.ops.aten.arange.start(sym_size_int, item, device = device(type='cpu'), pin_memory = False);  sym_size_int = item = arange = None

[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script... ❌
export failed due to Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
- Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
- Debug `torch.export.export` and summit a PR to PyTorch.
- Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.

## Exception summary

<class 'torch._dynamo.exc.UserError'>: Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of batch = L['args'][1]['input_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['attention_mask'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['key_cache'][1].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['value_cache'][1].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
Suggested fixes:
  batch = 2

(Refer to the full stack trace above for more information.)

The export fails for a couple of reasons, but it is possible to patch the code to make it work. All these modifications are put in place by bypass_export_some_errors (from module onnx_export_errors) and reverted once the export is done. Among other things, this function registers serialization functions, as shown in the example Export a model using a custom type as input.
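For illustration only, registering serialization (flatten/unflatten) functions for DynamicCache with torch.utils._pytree could look roughly like the sketch below; the actual patches applied by bypass_export_some_errors are more involved.

import torch.utils._pytree as pytree


def flatten_dynamic_cache(cache):
    # children first, then a context object (none needed here)
    return [cache.key_cache, cache.value_cache], None


def unflatten_dynamic_cache(children, context):
    cache = transformers.cache_utils.DynamicCache()
    cache.key_cache, cache.value_cache = children
    return cache


# This registration fails if the class is already registered
# (recent versions of transformers/torch do it themselves).
pytree.register_pytree_node(
    transformers.cache_utils.DynamicCache,
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
    serialized_type_name="transformers.cache_utils.DynamicCache",
)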

from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors(
    patch_transformers=True, replace_dynamic_cache=True, verbose=1
) as modificator:
    print("inputs before", string_type(inputs, with_shape=True))
    inputs = modificator(inputs)
    print("inputs after", string_type(inputs, with_shape=True))
    # ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
    ep = torch.onnx.export(
        model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
    )
    ep.optimize()
    ep.save("plot_exporter_recipes_oe_phi2.onnx")
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] modifies shape constraints
[bypass_export_some_errors] patch transformers
[bypass_export_some_errors] replace DynamicCache
inputs before dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
inputs after dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:patched_DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 31 of general pattern rewrite rules.
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored shape constraints
[bypass_export_some_errors] restored transformer
[bypass_export_some_errors] restored DynamicCache

Exported Model

Let’s display the model.

onx = onnx.load("plot_exporter_recipes_oe_phi2.onnx")
print(pretty_onnx(onx))
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib' version=1
input: name='input_ids' type=dtype('int64') shape=['s0', 's1']
input: name='attention_mask' type=dtype('int64') shape=['s0', 's1 + s11']
input: name='past_key_values_key_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_key_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
init: name='model.embed_tokens.weight' type=float32 shape=(51200, 2560)
init: name='model.layers.0.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.0.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.0.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.0.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.0.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.1.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.1.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.1.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.1.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.final_layernorm.weight' type=float32 shape=(2560,)
init: name='model.final_layernorm.bias' type=float32 shape=(2560,)
init: name='lm_head.weight' type=float32 shape=(51200, 2560)
init: name='lm_head.bias' type=float32 shape=(51200,)
Constant(value_int=1) -> diagonal
Shape(input_ids, end=1, start=0) -> val_0
  Squeeze(val_0) -> sym_size_int_51
Shape(input_ids, end=2, start=1) -> val_1
  Squeeze(val_1) -> sym_size_int_52
Shape(past_key_values_key_cache_0, end=3, start=2) -> val_2
  Squeeze(val_2) -> sym_size_int_53
    Add(sym_size_int_53, sym_size_int_52) -> add_4
Gather(model.embed_tokens.weight, input_ids, axis=0) -> embedding
  LayerNormalization(embedding, model.layers.0.input_layernorm.weight, model.layers.0.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm
Constant(value=1.0) -> val_3
Constant(value=1) -> val_4
  Range(sym_size_int_53, add_4, val_4) -> arange
Constant(value=0) -> dim_0
  Unsqueeze(arange, dim_0) -> unsqueeze
Constant(value=-3.4028234...) -> val_6
Constant(value=[-1]) -> val_7
  Reshape(sym_size_int_52, val_7, allowzero=0) -> val_8
Constant(value=[-1]) -> val_9
  Reshape(add_4, val_9, allowzero=0) -> val_10
    Concat(val_8, val_10, axis=0) -> val_11
  Expand(val_6, val_11) -> full
  Trilu(full, diagonal, upper=1) -> triu
Constant(value=0) -> val_14
Constant(value=1) -> val_15
  Range(val_14, add_4, val_15) -> arange_1
Constant(value=[-1, 1]) -> val_17
  Reshape(arange, val_17, allowzero=0) -> view
    Greater(arange_1, view) -> gt
      Cast(gt, to=1) -> convert_element_type_default
    Mul(triu, convert_element_type_default) -> mul_16
Constant(value=0) -> dim_0_2
  Unsqueeze(mul_16, dim_0_2) -> unsqueeze_3
Constant(value=1) -> dim_0_3
  Unsqueeze(unsqueeze_3, dim_0_3) -> unsqueeze_4
Constant(value=0) -> val_18
Constant(value=2) -> val_26
Constant(value=[-1]) -> val_42
  Reshape(sym_size_int_51, val_42, allowzero=0) -> val_43
Constant(value=[1]) -> val_44
Constant(value=[-1]) -> val_45
  Concat(val_43, val_44, val_45, val_45, axis=0) -> val_46
    Abs(val_46) -> size_1
    Expand(unsqueeze_4, size_1) -> expand_1
      Shape(expand_1, start=0) -> val_180
  Gather(val_180, val_26, axis=0) -> val_181
Constant(value=1) -> val_63
  Range(val_18, val_181, val_63) -> val_182
Constant(value=1) -> dim_0_4
  Unsqueeze(attention_mask, dim_0_4) -> unsqueeze_5
Constant(value=2) -> dim_0_5
  Unsqueeze(unsqueeze_5, dim_0_5) -> unsqueeze_6
    Cast(unsqueeze_6, to=1) -> convert_element_type_default_1
      Add(expand_1, convert_element_type_default_1) -> add_84
Constant(value=0.0) -> scalar_tensor_default
  Equal(add_84, scalar_tensor_default) -> eq_57
Constant(value=-3.4028234...) -> val_128
  Where(eq_57, val_128, expand_1) -> masked_fill
    Transpose(masked_fill, perm=[2,1,0,3]) -> val_189
Constant(value=-1) -> val_187
  Unsqueeze(val_182, val_187) -> val_188
Transpose(expand_1, perm=[2,1,0,3]) -> val_190
  ScatterND(val_190, val_188, val_189, reduction=b'none') -> val_191
    Transpose(val_191, perm=[2,1,0,3]) -> slice_scatter
      Transpose(slice_scatter, perm=[1,0,2,3]) -> val_201
Shape(expand_1, start=0) -> val_193
  Gather(val_193, val_63, axis=0) -> val_194
  Range(val_18, val_194, val_63) -> val_195
  Unsqueeze(val_195, val_187) -> val_200
Transpose(expand_1, perm=[1,0,2,3]) -> val_202
  ScatterND(val_202, val_200, val_201, reduction=b'none') -> val_203
    Transpose(val_203, perm=[1,0,2,3]) -> slice_scatter_1
Shape(expand_1, start=0) -> val_205
  Gather(val_205, val_18, axis=0) -> val_206
  Range(val_18, val_206, val_63) -> val_207
  Unsqueeze(val_207, val_187) -> val_212
    ScatterND(expand_1, val_212, slice_scatter_1, reduction=b'none') -> slice_scatter_2
Constant(value=1) -> dim_0_8
  Unsqueeze(unsqueeze, dim_0_8) -> unsqueeze_9
    Cast(unsqueeze_9, to=1) -> _to_copy_1
Constant(value=[[[1.0], [...) -> _to_copy_2
  MatMul(_to_copy_2, _to_copy_1) -> matmul
    Transpose(matmul, perm=[0,2,1]) -> transpose
      Concat(transpose, transpose, axis=-1) -> cat
        Cos(cat) -> cos
        Sin(cat) -> sin
Transpose(model.layers.0.self_attn.q_proj.weight, perm=[1,0]) -> val_244
  MatMul(layer_norm, val_244) -> val_245
    Add(val_245, model.layers.0.self_attn.q_proj.bias) -> linear
Constant(value=[-1]) -> val_246
  Reshape(sym_size_int_51, val_246, allowzero=0) -> val_247
Constant(value=[-1]) -> val_248
  Reshape(sym_size_int_52, val_248, allowzero=0) -> val_249
Constant(value=[80]) -> val_250
  Concat(val_247, val_249, val_45, val_250, axis=0) -> val_251
    Reshape(linear, val_251, allowzero=0) -> view_1
      Transpose(view_1, perm=[0,2,1,3]) -> transpose_1
Transpose(model.layers.0.self_attn.k_proj.weight, perm=[1,0]) -> val_253
  MatMul(layer_norm, val_253) -> val_254
    Add(val_254, model.layers.0.self_attn.k_proj.bias) -> linear_1
Constant(value=[-1]) -> val_255
  Reshape(sym_size_int_51, val_255, allowzero=0) -> val_256
Constant(value=[-1]) -> val_257
  Reshape(sym_size_int_52, val_257, allowzero=0) -> val_258
  Concat(val_256, val_258, val_45, val_250, axis=0) -> val_259
    Reshape(linear_1, val_259, allowzero=0) -> view_2
      Transpose(view_2, perm=[0,2,1,3]) -> transpose_2
Transpose(model.layers.0.self_attn.v_proj.weight, perm=[1,0]) -> val_261
  MatMul(layer_norm, val_261) -> val_262
    Add(val_262, model.layers.0.self_attn.v_proj.bias) -> linear_2
Constant(value=[-1]) -> val_263
  Reshape(sym_size_int_51, val_263, allowzero=0) -> val_264
Constant(value=[-1]) -> val_265
  Reshape(sym_size_int_52, val_265, allowzero=0) -> val_266
  Concat(val_264, val_266, val_45, val_250, axis=0) -> val_267
    Reshape(linear_2, val_267, allowzero=0) -> view_3
      Transpose(view_3, perm=[0,2,1,3]) -> transpose_3
        Concat(past_key_values_value_cache_0, transpose_3, axis=-2) -> cat_6
Constant(value=[0]) -> val_271
Constant(value=[32]) -> val_275
Constant(value=[3]) -> val_278
Constant(value_ints=[1]) -> val_279
  Slice(transpose_1, val_271, val_275, val_278, val_279) -> slice_24
Constant(value=[32]) -> val_282
Constant(value=[922337203...) -> val_285
Constant(value=[3]) -> val_288
Constant(value_ints=[1]) -> val_289
  Slice(transpose_1, val_282, val_285, val_288, val_289) -> slice_25
Constant(value=[0]) -> val_292
Constant(value=[32]) -> val_295
Constant(value=[3]) -> val_298
Constant(value_ints=[1]) -> val_299
  Slice(transpose_2, val_292, val_295, val_298, val_299) -> slice_26
Constant(value=[32]) -> val_302
Constant(value=[922337203...) -> val_305
Constant(value=[3]) -> val_308
Constant(value_ints=[1]) -> val_309
  Slice(transpose_2, val_302, val_305, val_308, val_309) -> slice_27
Constant(value=1) -> dim_0_9
  Unsqueeze(cos, dim_0_9) -> unsqueeze_10
    Mul(slice_24, unsqueeze_10) -> mul_214
Constant(value=1) -> dim_0_10
  Unsqueeze(sin, dim_0_10) -> unsqueeze_11
Constant(value=[0]) -> val_312
Constant(value=[16]) -> val_316
Constant(value=[3]) -> val_319
Constant(value_ints=[1]) -> val_320
  Slice(slice_24, val_312, val_316, val_319, val_320) -> slice_28
Constant(value=[16]) -> val_323
Constant(value=[922337203...) -> val_326
Constant(value=[3]) -> val_329
Constant(value_ints=[1]) -> val_330
  Slice(slice_24, val_323, val_326, val_329, val_330) -> slice_29
    Neg(slice_29) -> neg
    Concat(neg, slice_28, axis=-1) -> cat_1
    Mul(cat_1, unsqueeze_11) -> mul_231
      Add(mul_214, mul_231) -> add_288
    Concat(add_288, slice_25, axis=-1) -> cat_3
Mul(slice_26, unsqueeze_10) -> mul_239
Constant(value=[0]) -> val_333
Constant(value=[16]) -> val_336
Constant(value=[3]) -> val_339
Constant(value_ints=[1]) -> val_340
  Slice(slice_26, val_333, val_336, val_339, val_340) -> slice_30
Constant(value=[16]) -> val_343
Constant(value=[922337203...) -> val_346
Constant(value=[3]) -> val_349
Constant(value_ints=[1]) -> val_350
  Slice(slice_26, val_343, val_346, val_349, val_350) -> slice_31
    Neg(slice_31) -> neg_1
    Concat(neg_1, slice_30, axis=-1) -> cat_2
    Mul(cat_2, unsqueeze_11) -> mul_256
  Add(mul_239, mul_256) -> add_324
    Concat(add_324, slice_27, axis=-1) -> cat_4
      Concat(past_key_values_key_cache_0, cat_4, axis=-2) -> cat_5
        Shape(cat_5, start=0) -> val_383
Constant(value_ints=[9223372036854775807]) -> val_384
  Slice(val_383, val_45, val_384) -> val_385
Constant(value=[-2]) -> val_386
  Slice(val_383, val_386, val_45) -> val_387
Constant(value_ints=[-9223372036854775808]) -> val_388
  Slice(val_383, val_388, val_386) -> val_389
    Concat(val_389, val_385, val_387, axis=0) -> val_394
Constant(value_ints=[-1]) -> val_390
  Concat(val_390, val_387, val_385, axis=0) -> val_391
    Reshape(cat_5, val_391, allowzero=0) -> val_392
      Transpose(val_392, perm=[0,2,1]) -> val_393
      Reshape(val_393, val_394, allowzero=0) -> val_395
Constant(value=0.33437013...) -> val_396
  Mul(cat_3, val_396) -> val_397
Constant(value=0.33437013...) -> val_398
  Mul(val_395, val_398) -> val_399
    MatMul(val_397, val_399) -> val_400
      Add(val_400, slice_scatter_2) -> val_401
        Softmax(val_401, axis=-1) -> val_402
          MatMul(val_402, cat_6) -> scaled_dot_product_attention
            Transpose(scaled_dot_product_attention, perm=[0,2,1,3]) -> transpose_4
Constant(value=[-1]) -> val_405
  Reshape(sym_size_int_51, val_405, allowzero=0) -> val_406
Constant(value=[-1]) -> val_407
  Reshape(sym_size_int_52, val_407, allowzero=0) -> val_408
  Concat(val_406, val_408, val_45, axis=0) -> val_409
    Reshape(transpose_4, val_409, allowzero=0) -> view_4
Transpose(model.layers.0.self_attn.dense.weight, perm=[1,0]) -> val_411
  MatMul(view_4, val_411) -> val_412
    Add(val_412, model.layers.0.self_attn.dense.bias) -> linear_3
Transpose(model.layers.0.mlp.fc1.weight, perm=[1,0]) -> val_413
  MatMul(layer_norm, val_413) -> val_414
    Add(val_414, model.layers.0.mlp.fc1.bias) -> linear_4
Constant(value=0.5) -> val_415
  Mul(linear_4, val_415) -> mul_323
Constant(value=3.0) -> val_416
  Pow(linear_4, val_416) -> pow_1
Constant(value=0.04471499...) -> val_417
  Mul(pow_1, val_417) -> mul_330
    Add(linear_4, mul_330) -> add_408
Constant(value=0.79788458...) -> val_418
  Mul(add_408, val_418) -> mul_337
    Tanh(mul_337) -> tanh
  Add(tanh, val_3) -> add_421
    Mul(mul_323, add_421) -> mul_347
Transpose(model.layers.0.mlp.fc2.weight, perm=[1,0]) -> val_419
  MatMul(mul_347, val_419) -> val_420
    Add(val_420, model.layers.0.mlp.fc2.bias) -> linear_5
      Add(linear_3, linear_5) -> add_438
  Add(add_438, embedding) -> add_443
    LayerNormalization(add_443, model.layers.1.input_layernorm.weight, model.layers.1.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_1
Transpose(model.layers.1.self_attn.q_proj.weight, perm=[1,0]) -> val_421
  MatMul(layer_norm_1, val_421) -> val_422
    Add(val_422, model.layers.1.self_attn.q_proj.bias) -> linear_6
Constant(value=[-1]) -> val_423
  Reshape(sym_size_int_51, val_423, allowzero=0) -> val_424
Constant(value=[-1]) -> val_425
  Reshape(sym_size_int_52, val_425, allowzero=0) -> val_426
  Concat(val_424, val_426, val_45, val_250, axis=0) -> val_427
    Reshape(linear_6, val_427, allowzero=0) -> view_5
      Transpose(view_5, perm=[0,2,1,3]) -> transpose_5
Transpose(model.layers.1.self_attn.k_proj.weight, perm=[1,0]) -> val_429
  MatMul(layer_norm_1, val_429) -> val_430
    Add(val_430, model.layers.1.self_attn.k_proj.bias) -> linear_7
Constant(value=[-1]) -> val_431
  Reshape(sym_size_int_51, val_431, allowzero=0) -> val_432
Constant(value=[-1]) -> val_433
  Reshape(sym_size_int_52, val_433, allowzero=0) -> val_434
  Concat(val_432, val_434, val_45, val_250, axis=0) -> val_435
    Reshape(linear_7, val_435, allowzero=0) -> view_6
      Transpose(view_6, perm=[0,2,1,3]) -> transpose_6
Transpose(model.layers.1.self_attn.v_proj.weight, perm=[1,0]) -> val_437
  MatMul(layer_norm_1, val_437) -> val_438
    Add(val_438, model.layers.1.self_attn.v_proj.bias) -> linear_8
Constant(value=[-1]) -> val_439
  Reshape(sym_size_int_51, val_439, allowzero=0) -> val_440
Constant(value=[-1]) -> val_441
  Reshape(sym_size_int_52, val_441, allowzero=0) -> val_442
  Concat(val_440, val_442, val_45, val_250, axis=0) -> val_443
    Reshape(linear_8, val_443, allowzero=0) -> view_7
      Transpose(view_7, perm=[0,2,1,3]) -> transpose_7
        Concat(past_key_values_value_cache_1, transpose_7, axis=-2) -> cat_12
Constant(value=[0]) -> val_447
Constant(value=[32]) -> val_450
Constant(value=[3]) -> val_453
Constant(value_ints=[1]) -> val_454
  Slice(transpose_5, val_447, val_450, val_453, val_454) -> slice_38
Constant(value=[32]) -> val_457
Constant(value=[922337203...) -> val_460
Constant(value=[3]) -> val_463
Constant(value_ints=[1]) -> val_464
  Slice(transpose_5, val_457, val_460, val_463, val_464) -> slice_39
Constant(value=[0]) -> val_467
Constant(value=[32]) -> val_470
Constant(value=[3]) -> val_473
Constant(value_ints=[1]) -> val_474
  Slice(transpose_6, val_467, val_470, val_473, val_474) -> slice_40
Constant(value=[32]) -> val_477
Constant(value=[922337203...) -> val_480
Constant(value=[3]) -> val_483
Constant(value_ints=[1]) -> val_484
  Slice(transpose_6, val_477, val_480, val_483, val_484) -> slice_41
Constant(value=1) -> dim_0_11
  Unsqueeze(cos, dim_0_11) -> unsqueeze_12
    Mul(slice_38, unsqueeze_12) -> mul_413
Constant(value=1) -> dim_0_12
  Unsqueeze(sin, dim_0_12) -> unsqueeze_13
Constant(value=[0]) -> val_487
Constant(value=[16]) -> val_490
Constant(value=[3]) -> val_493
Constant(value_ints=[1]) -> val_494
  Slice(slice_38, val_487, val_490, val_493, val_494) -> slice_42
Constant(value=[16]) -> val_497
Constant(value=[922337203...) -> val_500
Constant(value=[3]) -> val_503
Constant(value_ints=[1]) -> val_504
  Slice(slice_38, val_497, val_500, val_503, val_504) -> slice_43
    Neg(slice_43) -> neg_2
    Concat(neg_2, slice_42, axis=-1) -> cat_7
    Mul(cat_7, unsqueeze_13) -> mul_430
      Add(mul_413, mul_430) -> add_550
    Concat(add_550, slice_39, axis=-1) -> cat_9
Mul(slice_40, unsqueeze_12) -> mul_438
Constant(value=[0]) -> val_507
Constant(value=[16]) -> val_510
Constant(value=[3]) -> val_513
Constant(value_ints=[1]) -> val_514
  Slice(slice_40, val_507, val_510, val_513, val_514) -> slice_44
Constant(value=[16]) -> val_517
Constant(value=[922337203...) -> val_520
Constant(value=[3]) -> val_523
Constant(value_ints=[1]) -> val_524
  Slice(slice_40, val_517, val_520, val_523, val_524) -> slice_45
    Neg(slice_45) -> neg_3
    Concat(neg_3, slice_44, axis=-1) -> cat_8
    Mul(cat_8, unsqueeze_13) -> mul_455
  Add(mul_438, mul_455) -> add_586
    Concat(add_586, slice_41, axis=-1) -> cat_10
      Concat(past_key_values_key_cache_1, cat_10, axis=-2) -> cat_11
        Shape(cat_11, start=0) -> val_556
  Slice(val_556, val_386, val_45) -> val_559
Constant(value_ints=[9223372036854775807]) -> val_557
  Slice(val_556, val_45, val_557) -> val_558
Constant(value_ints=[-9223372036854775808]) -> val_560
  Slice(val_556, val_560, val_386) -> val_561
    Concat(val_561, val_558, val_559, axis=0) -> val_566
Constant(value_ints=[-1]) -> val_562
  Concat(val_562, val_559, val_558, axis=0) -> val_563
    Reshape(cat_11, val_563, allowzero=0) -> val_564
      Transpose(val_564, perm=[0,2,1]) -> val_565
      Reshape(val_565, val_566, allowzero=0) -> val_567
Constant(value=0.33437013...) -> val_568
  Mul(cat_9, val_568) -> val_569
Constant(value=0.33437013...) -> val_570
  Mul(val_567, val_570) -> val_571
    MatMul(val_569, val_571) -> val_572
      Add(val_572, slice_scatter_2) -> val_573
        Softmax(val_573, axis=-1) -> val_574
          MatMul(val_574, cat_12) -> scaled_dot_product_attention_1
            Transpose(scaled_dot_product_attention_1, perm=[0,2,1,3]) -> transpose_8
Constant(value=[-1]) -> val_577
  Reshape(sym_size_int_51, val_577, allowzero=0) -> val_578
Constant(value=[-1]) -> val_579
  Reshape(sym_size_int_52, val_579, allowzero=0) -> val_580
  Concat(val_578, val_580, val_45, axis=0) -> val_581
    Reshape(transpose_8, val_581, allowzero=0) -> view_8
Transpose(model.layers.1.self_attn.dense.weight, perm=[1,0]) -> val_583
  MatMul(view_8, val_583) -> val_584
    Add(val_584, model.layers.1.self_attn.dense.bias) -> linear_9
Transpose(model.layers.1.mlp.fc1.weight, perm=[1,0]) -> val_585
  MatMul(layer_norm_1, val_585) -> val_586
    Add(val_586, model.layers.1.mlp.fc1.bias) -> linear_10
  Mul(linear_10, val_415) -> mul_522
Pow(linear_10, val_416) -> pow_2
  Mul(pow_2, val_417) -> mul_529
    Add(linear_10, mul_529) -> add_670
  Mul(add_670, val_418) -> mul_536
    Tanh(mul_536) -> tanh_1
  Add(tanh_1, val_3) -> add_683
    Mul(mul_522, add_683) -> mul_546
Transpose(model.layers.1.mlp.fc2.weight, perm=[1,0]) -> val_587
  MatMul(mul_546, val_587) -> val_588
    Add(val_588, model.layers.1.mlp.fc2.bias) -> linear_11
      Add(linear_9, linear_11) -> add_700
    Add(add_700, add_443) -> add_705
      LayerNormalization(add_705, model.final_layernorm.weight, model.final_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_2
Transpose(lm_head.weight, perm=[1,0]) -> val_619
  MatMul(layer_norm_2, val_619) -> val_620
    Add(val_620, lm_head.bias) -> linear_12
output: name='linear_12' type=dtype('float32') shape=['s0', 's1', 51200]
output: name='cat_5' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
output: name='cat_11' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
output: name='cat_6' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
output: name='cat_12' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]

Visually.
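The figure below was produced with plot_dot, imported at the top of the script. A minimal sketch of the call, assuming plot_dot accepts the ModelProto directly and renders the graph with matplotlib:

# Sketch: render the exported ONNX graph.
plot_dot(onx)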

(figure: plot exporter recipes oe phi2, rendering of the exported ONNX graph)

Total running time of the script: (0 minutes 24.413 seconds)

Related examples

to_onnx and Phi-2

Export Phi-3.5-mini-instruct piece by piece

Export Phi-3.5-mini-instruct with draft_export

Export Phi-3.5-mini-instruct with report_exportability

Do no use Module as inputs!

Gallery generated by Sphinx-Gallery