LLaMa

Dummy Example

LLaMa

<<<

import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from experimental_experiment.torch_interpreter import to_onnx


def ids_tensor(shape, vocab_size):
    """Build a tensor filled with random token ids.

    :param shape: sequence of integers, shape of the returned tensor
    :param vocab_size: size of the vocabulary, every id is drawn
        from ``0`` to ``vocab_size - 1`` (both included)
    :return: contiguous tensor of type ``torch.long`` and shape *shape*
    """
    dims = tuple(shape)
    # ``np.random.randint`` excludes its upper bound: passing ``vocab_size``
    # (and not ``vocab_size - 1``) keeps the last token id reachable.
    # One vectorized draw replaces one random call per element.
    values = np.random.randint(0, vocab_size, size=dims)
    return torch.tensor(values, dtype=torch.long).reshape(dims).contiguous()


# Small configuration (one hidden layer, hidden size 16) used for this
# dummy example.
config = LlamaConfig(
    vocab_size=1024,
    hidden_size=16,
    intermediate_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    max_position_embeddings=1024,
)
# Select the "eager" attention implementation.
config._attn_implementation = "eager"

with torch.no_grad():
    # The model is built from the configuration only.
    model = LlamaModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    # Random token ids and a float mask made of the lower triangular part
    # of a (batch, seq) matrix of ones.
    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.ones((batch, seq), dtype=torch.float32).tril()
    model_inputs = (input_ids, input_mask)

    # Run the model once with these inputs before exporting it.
    model(*model_inputs)

    # Convert the model to ONNX and print the graph as text.
    onx = to_onnx(model, model_inputs)
    text = onnx_simple_text_plot(onx)
    print(text)

>>>

    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    opset: domain='' version=18
    input: name='input_ids' type=dtype('int64') shape=[2, 1024]
    input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
    init: name='p_layers_0_input_layernorm_weight' type=dtype('float32') shape=(16,)
    init: name='p_layers_0_post_attention_layernorm_weight' type=dtype('float32') shape=(16,)
    init: name='p_norm_weight' type=dtype('float32') shape=(16,)
    init: name='p_embed_tokens_weight' type=dtype('float32') shape=(1024, 16)
    init: name='p_layers_0_self_attn_q_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='p_layers_0_self_attn_k_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='p_layers_0_self_attn_v_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='p_layers_0_self_attn_o_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='p_layers_0_mlp_gate_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='p_layers_0_mlp_up_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='p_layers_0_mlp_down_proj_weight' type=dtype('float32') shape=(16, 16)
    init: name='b_layers_0_self_attn_rotary_emb_inv_freq' type=dtype('float32') shape=(4,) -- array([1.   , 0.1  , 0.01 , 0.001], dtype=float32)
    init: name='init7_s_0' type=dtype('int64') shape=() -- array([0])
    init: name='init7_s_1024' type=dtype('int64') shape=() -- array([1024])
    init: name='init7_s_1' type=dtype('int64') shape=() -- array([1])
    init: name='init7_s2_1024_1024' type=dtype('int64') shape=(2,) -- array([1024, 1024])
    init: name='init7_s2_-1_1' type=dtype('int64') shape=(2,) -- array([-1,  1])
    init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
    init: name='init7_s4_2_1_1024_1024' type=dtype('int64') shape=(4,) -- array([   2,    1, 1024, 1024])
    init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
    init: name='init1_s_2' type=dtype('float32') shape=() -- array([0.], dtype=float32)
    init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
    init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
    init: name='init7_s1_-1' type=dtype('int64') shape=(1,) -- array([-1])
    init: name='init1_s_3' type=dtype('float32') shape=() -- array([1.e-06], dtype=float32)
    init: name='init7_s2_2048_16' type=dtype('int64') shape=(2,) -- array([2048,   16])
    init: name='init7_s3_2_1024_16' type=dtype('int64') shape=(3,) -- array([   2, 1024,   16])
    init: name='init7_s4_2_1024_2_8' type=dtype('int64') shape=(4,) -- array([   2, 1024,    2,    8])
    init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
    init: name='init1_s_5' type=dtype('float32') shape=() -- array([1.e-06], dtype=float32)
    init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
    init: name='init1_s_6' type=dtype('float32') shape=() -- array([1.e-06], dtype=float32)
    init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
    init: name='init7_s2_1_2' type=dtype('int64') shape=(2,) -- array([1, 2])
    init: name='init7_s2_0_2' type=dtype('int64') shape=(2,) -- array([0, 2])
    init: name='init1_s_7' type=dtype('float32') shape=() -- array([0.354], dtype=float32)
    init: name='init7_s2_4_4' type=dtype('int64') shape=(2,) -- array([4, 4])
    ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
      Trilu(full, init7_s_1, upper=1) -> _onx_trilu0
    Gather(p_embed_tokens_weight, input_ids) -> embedding
      Pow(embedding, init1_s1_2) -> pow_1
        ReduceMean(pow_1, init7_s1_-1, keepdims=1) -> mean
          Add(mean, init1_s_3) -> add
            Sqrt(add) -> _onx_sqrt0
              Reciprocal(_onx_sqrt0) -> rsqrt
      Mul(embedding, rsqrt) -> mul_2
        Mul(p_layers_0_input_layernorm_weight, mul_2) -> mul_3
          Reshape(mul_3, init7_s2_2048_16) -> view_1
            Gemm(view_1, p_layers_0_self_attn_v_proj_weight, transA=0, transB=1) -> mm_2
              Reshape(mm_2, init7_s4_2_1024_2_8) -> view_9
                Transpose(view_9, perm=[0,2,1,3]) -> output_2
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
      Unsqueeze(arange, init7_s2_0_1) -> unsqueeze_9
        Cast(unsqueeze_9, to=1) -> _to_copy_2
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
    Reshape(arange, init7_s2_-1_1) -> view
      Greater(arange_1, view) -> gt
        Cast(gt, to=1) -> _onx_cast0
        Mul(_onx_trilu0, _onx_cast0) -> _onx_mul0
          Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
            Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
              Equal(expand_1, init1_s_) -> eq
    Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
      Equal(unsqueeze_6, init1_s_2) -> eq_1
        And(eq, eq_1) -> mul_1
          Where(mul_1, init1_s1_, expand_1) -> _onx_where0
    Gemm(view_1, p_layers_0_self_attn_k_proj_weight, transA=0, transB=1) -> mm_1
      Reshape(mm_1, init7_s4_2_1024_2_8) -> view_8
        Transpose(view_8, perm=[0,2,1,3]) -> transpose_1
          Split(transpose_1, init7_s2_4_4, axis=3) -> slice_12, slice_13
            Neg(slice_13) -> neg_1
            Concat(neg_1, slice_12, axis=-1) -> cat_2
    Gemm(view_1, p_layers_0_self_attn_q_proj_weight, transA=0, transB=1) -> mm
      Reshape(mm, init7_s4_2_1024_2_8) -> view_7
        Transpose(view_7, perm=[0,2,1,3]) -> transpose
          Split(transpose, init7_s2_4_4, axis=3) -> slice_10, slice_11
            Neg(slice_11) -> neg
            Concat(neg, slice_10, axis=-1) -> cat_1
    Unsqueeze(b_layers_0_self_attn_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
      MatMul(unsqueeze_8, _to_copy_2) -> view_12
        Transpose(view_12, perm=[0,2,1]) -> transpose_3
          Concat(transpose_3, transpose_3, axis=-1) -> cat
            Cos(cat) -> cos
              Unsqueeze(cos, init7_s1_1) -> unsqueeze_10
          Mul(transpose, unsqueeze_10) -> mul_4
    Sin(cat) -> sin
      Unsqueeze(sin, init7_s1_1) -> unsqueeze_11
        Mul(cat_1, unsqueeze_11) -> mul_5
          Add(mul_4, mul_5) -> add_1
    Mul(transpose_1, unsqueeze_10) -> mul_6
    Mul(cat_2, unsqueeze_11) -> mul_7
      Add(mul_6, mul_7) -> output_1
        Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
          MatMul(add_1, transpose_4) -> view_13
            Mul(view_13, init1_s_7) -> div
            Add(div, _onx_where0) -> add_3
              Softmax(add_3, axis=-1) -> _softmax
                MatMul(_softmax, output_2) -> view_15
                  Transpose(view_15, perm=[0,2,1,3]) -> transpose_5
                    Reshape(transpose_5, init7_s2_2048_16) -> view_17
                      Gemm(view_17, p_layers_0_self_attn_o_proj_weight, transA=0, transB=1) -> mm_3
                        Reshape(mm_3, init7_s3_2_1024_16) -> view_18
      Add(embedding, view_18) -> add_4
        Pow(add_4, init1_s1_3) -> pow_2
          ReduceMean(pow_2, init7_s1_-1, keepdims=1) -> mean_1
            Add(mean_1, init1_s_5) -> add_5
              Sqrt(add_5) -> _onx_sqrt02
                Reciprocal(_onx_sqrt02) -> rsqrt_1
        Mul(add_4, rsqrt_1) -> mul_8
          Mul(p_layers_0_post_attention_layernorm_weight, mul_8) -> mul_9
            Reshape(mul_9, init7_s2_2048_16) -> view_19
              Gemm(view_19, p_layers_0_mlp_up_proj_weight, transA=0, transB=1) -> mm_5
    Gemm(view_19, p_layers_0_mlp_gate_proj_weight, transA=0, transB=1) -> mm_4
      Reshape(mm_4, init7_s3_2_1024_16) -> view_20
        Sigmoid(view_20) -> _onx_sigmoid0
          Reshape(_onx_sigmoid0, init7_s2_2048_16) -> Reshape2Of3PatternR__onx_sigmoid0
      Mul(mm_4, Reshape2Of3PatternR__onx_sigmoid0) -> Reshape2Of3PatternL_silu
        Mul(Reshape2Of3PatternL_silu, mm_5) -> view_23
          Gemm(view_23, p_layers_0_mlp_down_proj_weight, transA=0, transB=1) -> mm_6
            Reshape(mm_6, init7_s3_2_1024_16) -> view_24
        Add(add_4, view_24) -> add_6
          Pow(add_6, init1_s1_4) -> pow_3
            ReduceMean(pow_3, init7_s1_-1, keepdims=1) -> mean_2
              Add(mean_2, init1_s_6) -> add_7
                Sqrt(add_7) -> _onx_sqrt03
                  Reciprocal(_onx_sqrt03) -> rsqrt_2
          Mul(add_6, rsqrt_2) -> mul_11
            Mul(p_norm_weight, mul_11) -> output_0
    output: name='output_0' type=dtype('float32') shape=[2, 1024, 16]
    output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 8]
    output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 8]

Full Example

import torch
from transformers import AutoConfig, AutoModelForCausalLM

location = "meta-llama/Llama-2-7b-hf"
# Folder given to ``from_pretrained`` as ``cache_dir`` for both calls below.
cache_dir = "_cache"
# NOTE(review): ``use_auth_token`` was used without being defined in this
# example. ``True`` is assumed to select the token stored by
# ``huggingface-cli login``; replace it with an explicit token string if
# needed -- TODO confirm against the installed version of transformers.
use_auth_token = True

# Load the configuration first so that ``use_cache`` can be switched on
# before the model is created from it.
l_config = AutoConfig.from_pretrained(
    location, use_auth_token=use_auth_token, cache_dir=cache_dir
)
l_config.use_cache = True
llama = AutoModelForCausalLM.from_pretrained(
    location,
    use_auth_token=use_auth_token,
    config=l_config,
    torch_dtype=torch.float32,
    cache_dir=cache_dir,
)

Llama 3

See Llama3.

import os
import time
import onnxruntime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from experimental_experiment.torch_interpreter import to_onnx

model_id = "meta-llama/Meta-Llama-3-8B"

# Loading, the reference run and the conversion all happen without
# gradient tracking.
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model = model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Tokenize one prompt; its ids are the only input given to the model.
    base_prompt = "Is the conversion to onnx going to work?"
    base_inputs = tokenizer(base_prompt, return_tensors="pt")  # .to("cpu")
    input_ids = base_inputs.input_ids

    # Reference output from torch, compared later with onnxruntime.
    expected = model(input_ids)
    print(f"output type: {type(expected)}")
    print(f"logits: {expected.logits.shape}, {expected.logits.dtype}")

    print("start conversion... with input_ids", input_ids.dtype, input_ids.shape)

    # dynamic_shapes is left out of the options, it fails with
    # transformers==4.37.2:
    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
    # dynamic_shapes={"x": {1: torch.export.Dim("length", min=2)}},
    export_options = dict(input_names=["x"], verbose=1, large_model=True)

    begin = time.perf_counter()
    large_onx = to_onnx(model, (input_ids,), **export_options)
    duration = time.perf_counter() - begin
    print(f"conversion done in {duration}s")

folder = "test_zoo_export_llama3"
if os.path.exists(folder):
    # Remove every entry already in the folder before saving into it.
    for leftover in os.listdir(folder):
        os.remove(os.path.join(folder, leftover))
else:
    os.mkdir(folder)

print(f"start saving in {folder!r}")
filename = os.path.join(folder, "llama3.onnx")
# Only the call to ``save`` is timed; the message then reports how many
# files the folder contains.
begin = time.perf_counter()
large_onx.save(filename)
duration = time.perf_counter() - begin
print(f"saving done in {duration}s with {len(os.listdir(folder))} files")

print(f"loading model {filename!r} with onnxruntime.")
begin = time.perf_counter()
sess = onnxruntime.InferenceSession(
    filename, providers=["CPUExecutionProvider"]
)
print(f"done in {time.perf_counter() - begin}s")

print("running the first iteration")
begin = time.perf_counter()
# The input name is read from the exported graph.
name = large_onx.model_proto.graph.input[0].name
np_input = input_ids.detach().cpu().numpy()
got = sess.run(None, {name: np_input})
print(f"done in {time.perf_counter() - begin}s")
# ``self.assertEqualArray`` cannot be called here: this is a script and not
# a method of a test case, ``self`` does not exist. The first onnxruntime
# output is compared to the torch logits with the same absolute tolerance.
# NOTE(review): rtol=0 assumes the original check was purely absolute since
# only ``atol`` was given -- TODO confirm.
torch.testing.assert_close(
    torch.from_numpy(got[0]), expected.logits, atol=1e-4, rtol=0
)

# Number of timed iterations for each runtime.
N = 5
print(f"running {N} iterations with torch")
begin = time.perf_counter()
# Gradient tracking is disabled here as it was for the reference run, so
# that no autograd bookkeeping is included in the torch measurement.
with torch.no_grad():
    for _ in range(N):
        model(input_ids)
d = time.perf_counter() - begin
print(f"done in {d}s for torch")

print(f"running {N} iterations with onnxruntime")
begin = time.perf_counter()
for _ in range(N):
    sess.run(None, {name: np_input})
d = time.perf_counter() - begin
print(f"done in {d}s for onnxruntime")