LLaMA

Dummy Example: export a tiny, randomly initialized LlamaModel to ONNX

LLaMA

<<<

import numpy as np
from experimental_experiment.helpers import pretty_onnx
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions


def ids_tensor(shape, vocab_size):
    """Return a tensor of random token ids.

    :param shape: shape of the returned tensor (sequence of ints)
    :param vocab_size: number of tokens in the vocabulary; ids are drawn
        uniformly from ``[0, vocab_size)``
    :return: contiguous ``torch.long`` tensor of shape *shape*
    """
    # np.prod of an empty shape is 1, which yields a 0-d tensor as before.
    total_dims = int(np.prod(shape, dtype=np.int64))
    # ``np.random.randint`` excludes its upper bound, so pass ``vocab_size``
    # (not ``vocab_size - 1``) to make the last token id reachable.
    values = np.random.randint(0, vocab_size, size=total_dims, dtype=np.int64)
    return torch.tensor(values, dtype=torch.long).view(shape).contiguous()


# Tiny configuration: one decoder layer and 16 hidden units keep the
# exported graph small enough to read.
config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
# Force the "eager" attention implementation for the export.
config._attn_implementation = "eager"

with torch.no_grad():

    model = LlamaModel(config)

    # Take the vocabulary size from the configuration so the random ids
    # always stay within the embedding table of the model.
    batch, seq = 2, 1024
    vocab_size = config.vocab_size

    input_ids = ids_tensor([batch, seq], vocab_size)
    # Float mask of shape (batch, seq); torch.tril keeps only the first
    # i + 1 positions of row i, so each batch row sees a different prefix.
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # Eager forward pass whose result is discarded -- presumably a smoke
    # test that the model accepts these inputs before exporting.
    model(input_ids, input_mask)

    onx = to_onnx(
        model,
        (input_ids, input_mask),
        export_options=ExportOptions(decomposition_table="default"),
    )
    print(pretty_onnx(onx))

>>>

    opset: domain='' version=18
    doc_string: large_model=False, inline=False, external_threshold=102...
    input: name='input_ids' type=dtype('int64') shape=[2, 1024]
    input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
    init: name='b_rotary_emb_inv_freq' type=float32 shape=(4,) -- array([1.   , 0.1  , 0.01 , 0.001], dtype=float32)-- DynamoInterpret.placeholder.0
    init: name='init7_s_0' type=int64 shape=() -- array([0])              -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s_1024' type=int64 shape=() -- array([1024])        -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s_1' type=int64 shape=() -- array([1])              -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s2_1024_1024' type=int64 shape=(2,) -- array([1024, 1024])-- Opset.make_node.1/Shape
    init: name='init7_s2_-1_1' type=int64 shape=(2,) -- array([-1,  1])   -- Opset.make_node.1/Shape
    init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s4_2_1_1024_1024' type=int64 shape=(4,) -- array([   2,    1, 1024, 1024])-- GraphBuilder.make_shape_from_results.shape
    init: name='init1_s_' type=float32 shape=() -- array([0.], dtype=float32)-- shape_type_compute._cast_inputs.0
    init: name='init1_s1_' type=float32 shape=(1,) -- array([-3.403e+38], dtype=float32)-- Opset.make_node.1/Small
    init: name='init1_s1_2' type=float32 shape=(1,) -- array([2.], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small##Opset.make_node.1/Small
    init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_3' type=float32 shape=() -- array([1.e-06], dtype=float32)-- shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
    init: name='init7_s4_2_1024_-1_8' type=int64 shape=(4,) -- array([   2, 1024,   -1,    8])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_4' type=float32 shape=() -- array([2.828], dtype=float32)-- shape_type_compute._cast_inputs.0
    init: name='init7_s3_2_1024_-1' type=int64 shape=(3,) -- array([   2, 1024,   -1])-- Opset.make_node.1/Shape
    init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1])      -- UnsqueezeUnsqueezePattern.apply.new_axis##UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_1_2' type=int64 shape=(2,) -- array([1, 2])      -- UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_0_2' type=int64 shape=(2,) -- array([0, 2])      -- UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16])  -- MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3##MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3
    init: name='init7_s3_2_1024_-12' type=int64 shape=(3,) -- array([   2, 1024,   -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
    init: name='init7_s2_4_4' type=int64 shape=(2,) -- array([4, 4])      -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
    init: name='layers.0.input_layernorm.weight' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(layers.0.input_layernorm.weight)
    init: name='layers.0.post_attention_layernorm.weight' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(layers.0.post_attention_layernorm.weight)
    init: name='layers.0.self_attn.q_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.q_proj.weight)
    init: name='layers.0.self_attn.k_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.k_proj.weight)
    init: name='layers.0.self_attn.v_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.v_proj.weight)
    init: name='layers.0.self_attn.o_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.o_proj.weight)
    init: name='layers.0.mlp.gate_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.mlp.gate_proj.weight)
    init: name='layers.0.mlp.up_proj.weight' type=float32 shape=(16, 16)  -- DynamoInterpret.placeholder.1/P(layers.0.mlp.up_proj.weight)
    init: name='layers.0.mlp.down_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.mlp.down_proj.weight)
    init: name='norm.weight' type=float32 shape=(16,)                     -- DynamoInterpret.placeholder.1/P(norm.weight)
    init: name='embed_tokens.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embed_tokens.weight)
    ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
      Trilu(full, init7_s_1, upper=1) -> triu
    Gather(embed_tokens.weight, input_ids) -> embedding
      Pow(embedding, init1_s1_2) -> pow_1
        ReduceMean(pow_1, init7_s1_-1, keepdims=1) -> mean
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
      Unsqueeze(arange, init7_s2_0_1) -> unsqueeze_9
        Cast(unsqueeze_9, to=1) -> _to_copy_1
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
    Reshape(arange, init7_s2_-1_1) -> view
      Greater(arange_1, view) -> gt
        Cast(gt, to=1) -> _onx_cast0
        Mul(triu, _onx_cast0) -> _onx_mul0
          Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
            Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
    Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
      Add(expand_1, unsqueeze_6) -> add
    Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
      Equal(add, _onx_reshape0) -> eq
        Where(eq, init1_s1_, expand_1) -> masked_fill
    Unsqueeze(b_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
      MatMul(unsqueeze_8, _to_copy_1) -> matmul
        Transpose(matmul, perm=[0,2,1]) -> transpose
          Concat(transpose, transpose, axis=-1) -> cat
            Cos(cat) -> cos
              Unsqueeze(cos, init7_s1_1) -> unsqueeze_10
            Sin(cat) -> sin
              Unsqueeze(sin, init7_s1_1) -> unsqueeze_11
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape04
      Add(mean, _onx_reshape04) -> add_1
        Sqrt(add_1) -> _onx_sqrt0
          Reciprocal(_onx_sqrt0) -> rsqrt
      Mul(embedding, rsqrt) -> mul_3
        Mul(layers.0.input_layernorm.weight, mul_3) -> mul_4
    Transpose(layers.0.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose0
      MatMul(mul_4, _onx_transpose0) -> linear
        Reshape(linear, init7_s4_2_1024_-1_8) -> view_1
          Transpose(view_1, perm=[0,2,1,3]) -> transpose_1
            Split(transpose_1, init7_s2_4_4, axis=3) -> slice_24, slice_25
              Neg(slice_25) -> neg
              Concat(neg, slice_24, axis=-1) -> cat_1
                Mul(cat_1, unsqueeze_11) -> mul_6
    Transpose(layers.0.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose02
      MatMul(mul_4, _onx_transpose02) -> linear_1
        Reshape(linear_1, init7_s4_2_1024_-1_8) -> view_2
          Transpose(view_2, perm=[0,2,1,3]) -> transpose_2
            Split(transpose_2, init7_s2_4_4, axis=3) -> slice_26, slice_27
              Neg(slice_27) -> neg_1
              Concat(neg_1, slice_26, axis=-1) -> cat_2
                Mul(cat_2, unsqueeze_11) -> mul_8
    Transpose(layers.0.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose03
      MatMul(mul_4, _onx_transpose03) -> linear_2
        Reshape(linear_2, init7_s4_2_1024_-1_8) -> view_3
          Transpose(view_3, perm=[0,2,1,3]) -> output_2
    Mul(transpose_1, unsqueeze_10) -> mul_5
      Add(mul_5, mul_6) -> add_2
    Mul(transpose_2, unsqueeze_10) -> mul_7
      Add(mul_7, mul_8) -> output_1
        Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
        MatMul(add_2, transpose_4) -> matmul_1
    Reshape(init1_s_4, init7_s1_1) -> _onx_reshape05
      Div(matmul_1, _onx_reshape05) -> div
        Add(div, masked_fill) -> add_4
          Softmax(add_4, axis=-1) -> softmax
            MatMul(softmax, output_2) -> matmul_2
              Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
                Reshape(transpose_5, init7_s3_2_1024_-1) -> view_4
                  Reshape(view_4, init7_s2_-1_16) -> MatMulAddPattern--view_4
      Reshape(embedding, init7_s2_-1_16) -> MatMulAddPattern--view_42
        Gemm(MatMulAddPattern--view_4, layers.0.self_attn.o_proj.weight, MatMulAddPattern--view_42, transB=1) -> MatMulAddPattern--view_43
          Reshape(MatMulAddPattern--view_43, init7_s3_2_1024_-12) -> add_5
            Pow(add_5, init1_s1_2) -> pow_2
              ReduceMean(pow_2, init7_s1_-1, keepdims=1) -> mean_1
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape06
      Add(mean_1, _onx_reshape06) -> add_6
        Sqrt(add_6) -> _onx_sqrt02
          Reciprocal(_onx_sqrt02) -> rsqrt_1
            Mul(add_5, rsqrt_1) -> mul_9
              Mul(layers.0.post_attention_layernorm.weight, mul_9) -> mul_10
    Transpose(layers.0.mlp.gate_proj.weight, perm=[1,0]) -> _onx_transpose05
      MatMul(mul_10, _onx_transpose05) -> linear_4
        Sigmoid(linear_4) -> _onx_sigmoid0
        Mul(linear_4, _onx_sigmoid0) -> silu
    Transpose(layers.0.mlp.up_proj.weight, perm=[1,0]) -> _onx_transpose06
      MatMul(mul_10, _onx_transpose06) -> linear_5
        Mul(silu, linear_5) -> mul_11
          Reshape(mul_11, init7_s2_-1_16) -> MatMulAddPattern--mul_11
    Reshape(add_5, init7_s2_-1_16) -> MatMulAddPattern--mul_112
      Gemm(MatMulAddPattern--mul_11, layers.0.mlp.down_proj.weight, MatMulAddPattern--mul_112, transB=1) -> MatMulAddPattern--mul_113
        Reshape(MatMulAddPattern--mul_113, init7_s3_2_1024_-12) -> add_7
          Pow(add_7, init1_s1_2) -> pow_3
            ReduceMean(pow_3, init7_s1_-1, keepdims=1) -> mean_2
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape07
      Add(mean_2, _onx_reshape07) -> add_8
        Sqrt(add_8) -> _onx_sqrt03
          Reciprocal(_onx_sqrt03) -> rsqrt_2
          Mul(add_7, rsqrt_2) -> mul_12
            Mul(norm.weight, mul_12) -> output_0
    output: name='output_0' type=dtype('float32') shape=[2, 1024, 16]
    output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 8]
    output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 8]

Full Example: load the pretrained Llama-2-7b model from the Hugging Face Hub

import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM

location = "meta-llama/Llama-2-7b-hf"
# Local folder where the downloaded files are cached.
cache_dir = "_cache"
# Llama-2 is a gated repository, so an access token is needed. It is read
# from the HF_TOKEN environment variable. When unset (None), transformers
# uses the token stored by a previous Hugging Face login, if any.
token = os.environ.get("HF_TOKEN")

l_config = AutoConfig.from_pretrained(location, token=token, cache_dir=cache_dir)
l_config.use_cache = True
llama = AutoModelForCausalLM.from_pretrained(
    location,
    token=token,
    config=l_config,
    torch_dtype=torch.float32,
    cache_dir=cache_dir,
)

Llama 3

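See Llama 3 (``meta-llama/Meta-Llama-3-8B`` on the Hugging Face Hub). The script below exports the model to ONNX, saves it, and checks the onnxruntime logits against PyTorch. It then compares the running time of both runtimes.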

import os
import time

import numpy as np
import onnxruntime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from experimental_experiment.torch_interpreter import to_onnx

model_id = "meta-llama/Meta-Llama-3-8B"

with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    base_prompt = "Is the conversion to onnx going to work?"
    base_inputs = tokenizer(base_prompt, return_tensors="pt")
    input_ids = base_inputs.input_ids
    # Reference output computed with PyTorch, compared with onnxruntime below.
    expected = model(input_ids)

    print(f"output type: {type(expected)}")
    print(f"logits: {expected.logits.shape}, {expected.logits.dtype}")

    print(
        "start conversion... with input_ids", input_ids.dtype, input_ids.shape
    )
    begin = time.perf_counter()
    large_onx = to_onnx(
        model,
        (input_ids,),
        input_names=["x"],
        verbose=1,
        large_model=True,
        # dynamic_shapes fails with transformers==4.37.2
        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
        # dynamic_shapes={"x": {1: torch.export.Dim("length", min=2)}},
    )
    duration = time.perf_counter() - begin
    print(f"conversion done in {duration}s")

# Start from an empty folder. The saved model may span several files (see the
# file count printed below), so files from a previous run must not remain.
folder = "test_zoo_export_llama3"
os.makedirs(folder, exist_ok=True)
for entry in os.listdir(folder):
    os.remove(os.path.join(folder, entry))

print(f"start saving in {folder!r}")
begin = time.perf_counter()
filename = os.path.join(folder, "llama3.onnx")
large_onx.save(filename)
duration = time.perf_counter() - begin
print(f"saving done in {duration}s with {len(os.listdir(folder))} files")

print(f"loading model {filename!r} with onnxruntime.")
begin = time.perf_counter()
sess = onnxruntime.InferenceSession(
    filename, providers=["CPUExecutionProvider"]
)
print(f"done in {time.perf_counter() - begin}s")

print("running the first iteration")
begin = time.perf_counter()
name = large_onx.model_proto.graph.input[0].name
np_input = input_ids.detach().cpu().numpy()
got = sess.run(None, {name: np_input})
print(f"done in {time.perf_counter() - begin}s")
# Outside a unittest class there is no ``self.assertEqualArray``; compare the
# logits with numpy instead.
np.testing.assert_allclose(
    expected.logits.detach().cpu().numpy(), got[0], atol=1e-4
)

N = 5
print(f"running {N} iterations with torch")
begin = time.perf_counter()
# no_grad: inference only, and avoids building autograd graphs that would
# skew the timing and waste memory.
with torch.no_grad():
    for _ in range(N):
        model(input_ids)
d = time.perf_counter() - begin
print(f"done in {d}s for torch")

print(f"running {N} iterations with onnxruntime")
begin = time.perf_counter()
for _ in range(N):
    sess.run(None, {name: np_input})
d = time.perf_counter() - begin
print(f"done in {d}s for onnxruntime")