LLaMa¶
Dummy Example¶
<<<
import numpy as np
from experimental_experiment.helpers import pretty_onnx
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions


def ids_tensor(shape, vocab_size):
    # Random token ids of the requested shape, drawn from [0, vocab_size - 1).
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


# A tiny configuration so that the example stays fast.
config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"

with torch.no_grad():
    model = LlamaModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # one forward call in eager mode, then the conversion to ONNX
    model(input_ids, input_mask)

    onx = to_onnx(
        model,
        (input_ids, input_mask),
        export_options=ExportOptions(decomposition_table="default"),
    )
    print(pretty_onnx(onx))
>>>
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[2, 1024]
input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
init: name='b_rotary_emb_inv_freq' type=float32 shape=(4,) -- array([1. , 0.1 , 0.01 , 0.001], dtype=float32)-- DynamoInterpret.placeholder.0
init: name='init7_s_0' type=int64 shape=() -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s_1024' type=int64 shape=() -- array([1024]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s_1' type=int64 shape=() -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s2_1024_1024' type=int64 shape=(2,) -- array([1024, 1024])-- Opset.make_node.1/Shape
init: name='init7_s2_-1_1' type=int64 shape=(2,) -- array([-1, 1]) -- Opset.make_node.1/Shape
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s4_2_1_1024_1024' type=int64 shape=(4,) -- array([ 2, 1, 1024, 1024])-- GraphBuilder.make_shape_from_results.shape
init: name='init1_s_' type=float32 shape=() -- array([0.], dtype=float32)-- shape_type_compute._cast_inputs.0
init: name='init1_s1_' type=float32 shape=(1,) -- array([-3.403e+38], dtype=float32)-- Opset.make_node.1/Small
init: name='init1_s1_2' type=float32 shape=(1,) -- array([2.], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_3' type=float32 shape=() -- array([1.e-06], dtype=float32)-- shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
init: name='init7_s4_2_1024_-1_8' type=int64 shape=(4,) -- array([ 2, 1024, -1, 8])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_4' type=float32 shape=() -- array([2.828], dtype=float32)-- shape_type_compute._cast_inputs.0
init: name='init7_s3_2_1024_-1' type=int64 shape=(3,) -- array([ 2, 1024, -1])-- Opset.make_node.1/Shape
init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1]) -- UnsqueezeUnsqueezePattern.apply.new_axis##UnsqueezeUnsqueezePattern.apply.new_axis
init: name='init7_s2_1_2' type=int64 shape=(2,) -- array([1, 2]) -- UnsqueezeUnsqueezePattern.apply.new_axis
init: name='init7_s2_0_2' type=int64 shape=(2,) -- array([0, 2]) -- UnsqueezeUnsqueezePattern.apply.new_axis
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3##MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3
init: name='init7_s3_2_1024_-12' type=int64 shape=(3,) -- array([ 2, 1024, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_4_4' type=int64 shape=(2,) -- array([4, 4]) -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
init: name='layers.0.input_layernorm.weight' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(layers.0.input_layernorm.weight)
init: name='layers.0.post_attention_layernorm.weight' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(layers.0.post_attention_layernorm.weight)
init: name='layers.0.self_attn.q_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.q_proj.weight)
init: name='layers.0.self_attn.k_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.k_proj.weight)
init: name='layers.0.self_attn.v_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.v_proj.weight)
init: name='layers.0.self_attn.o_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.o_proj.weight)
init: name='layers.0.mlp.gate_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.mlp.gate_proj.weight)
init: name='layers.0.mlp.up_proj.weight' type=float32 shape=(16, 16) -- DynamoInterpret.placeholder.1/P(layers.0.mlp.up_proj.weight)
init: name='layers.0.mlp.down_proj.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(layers.0.mlp.down_proj.weight)
init: name='norm.weight' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(norm.weight)
init: name='embed_tokens.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embed_tokens.weight)
ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
Trilu(full, init7_s_1, upper=1) -> triu
Gather(embed_tokens.weight, input_ids) -> embedding
Pow(embedding, init1_s1_2) -> pow_1
ReduceMean(pow_1, init7_s1_-1, keepdims=1) -> mean
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
Unsqueeze(arange, init7_s2_0_1) -> unsqueeze_9
Cast(unsqueeze_9, to=1) -> _to_copy_1
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
Reshape(arange, init7_s2_-1_1) -> view
Greater(arange_1, view) -> gt
Cast(gt, to=1) -> _onx_cast0
Mul(triu, _onx_cast0) -> _onx_mul0
Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
Add(expand_1, unsqueeze_6) -> add
Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
Equal(add, _onx_reshape0) -> eq
Where(eq, init1_s1_, expand_1) -> masked_fill
Unsqueeze(b_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
MatMul(unsqueeze_8, _to_copy_1) -> matmul
Transpose(matmul, perm=[0,2,1]) -> transpose
Concat(transpose, transpose, axis=-1) -> cat
Cos(cat) -> cos
Unsqueeze(cos, init7_s1_1) -> unsqueeze_10
Sin(cat) -> sin
Unsqueeze(sin, init7_s1_1) -> unsqueeze_11
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape04
Add(mean, _onx_reshape04) -> add_1
Sqrt(add_1) -> _onx_sqrt0
Reciprocal(_onx_sqrt0) -> rsqrt
Mul(embedding, rsqrt) -> mul_3
Mul(layers.0.input_layernorm.weight, mul_3) -> mul_4
Transpose(layers.0.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose0
MatMul(mul_4, _onx_transpose0) -> linear
Reshape(linear, init7_s4_2_1024_-1_8) -> view_1
Transpose(view_1, perm=[0,2,1,3]) -> transpose_1
Split(transpose_1, init7_s2_4_4, axis=3) -> slice_24, slice_25
Neg(slice_25) -> neg
Concat(neg, slice_24, axis=-1) -> cat_1
Mul(cat_1, unsqueeze_11) -> mul_6
Transpose(layers.0.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose02
MatMul(mul_4, _onx_transpose02) -> linear_1
Reshape(linear_1, init7_s4_2_1024_-1_8) -> view_2
Transpose(view_2, perm=[0,2,1,3]) -> transpose_2
Split(transpose_2, init7_s2_4_4, axis=3) -> slice_26, slice_27
Neg(slice_27) -> neg_1
Concat(neg_1, slice_26, axis=-1) -> cat_2
Mul(cat_2, unsqueeze_11) -> mul_8
Transpose(layers.0.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose03
MatMul(mul_4, _onx_transpose03) -> linear_2
Reshape(linear_2, init7_s4_2_1024_-1_8) -> view_3
Transpose(view_3, perm=[0,2,1,3]) -> output_2
Mul(transpose_1, unsqueeze_10) -> mul_5
Add(mul_5, mul_6) -> add_2
Mul(transpose_2, unsqueeze_10) -> mul_7
Add(mul_7, mul_8) -> output_1
Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
MatMul(add_2, transpose_4) -> matmul_1
Reshape(init1_s_4, init7_s1_1) -> _onx_reshape05
Div(matmul_1, _onx_reshape05) -> div
Add(div, masked_fill) -> add_4
Softmax(add_4, axis=-1) -> softmax
MatMul(softmax, output_2) -> matmul_2
Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
Reshape(transpose_5, init7_s3_2_1024_-1) -> view_4
Reshape(view_4, init7_s2_-1_16) -> MatMulAddPattern--view_4
Reshape(embedding, init7_s2_-1_16) -> MatMulAddPattern--view_42
Gemm(MatMulAddPattern--view_4, layers.0.self_attn.o_proj.weight, MatMulAddPattern--view_42, transB=1) -> MatMulAddPattern--view_43
Reshape(MatMulAddPattern--view_43, init7_s3_2_1024_-12) -> add_5
Pow(add_5, init1_s1_2) -> pow_2
ReduceMean(pow_2, init7_s1_-1, keepdims=1) -> mean_1
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape06
Add(mean_1, _onx_reshape06) -> add_6
Sqrt(add_6) -> _onx_sqrt02
Reciprocal(_onx_sqrt02) -> rsqrt_1
Mul(add_5, rsqrt_1) -> mul_9
Mul(layers.0.post_attention_layernorm.weight, mul_9) -> mul_10
Transpose(layers.0.mlp.gate_proj.weight, perm=[1,0]) -> _onx_transpose05
MatMul(mul_10, _onx_transpose05) -> linear_4
Sigmoid(linear_4) -> _onx_sigmoid0
Mul(linear_4, _onx_sigmoid0) -> silu
Transpose(layers.0.mlp.up_proj.weight, perm=[1,0]) -> _onx_transpose06
MatMul(mul_10, _onx_transpose06) -> linear_5
Mul(silu, linear_5) -> mul_11
Reshape(mul_11, init7_s2_-1_16) -> MatMulAddPattern--mul_11
Reshape(add_5, init7_s2_-1_16) -> MatMulAddPattern--mul_112
Gemm(MatMulAddPattern--mul_11, layers.0.mlp.down_proj.weight, MatMulAddPattern--mul_112, transB=1) -> MatMulAddPattern--mul_113
Reshape(MatMulAddPattern--mul_113, init7_s3_2_1024_-12) -> add_7
Pow(add_7, init1_s1_2) -> pow_3
ReduceMean(pow_3, init7_s1_-1, keepdims=1) -> mean_2
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape07
Add(mean_2, _onx_reshape07) -> add_8
Sqrt(add_8) -> _onx_sqrt03
Reciprocal(_onx_sqrt03) -> rsqrt_2
Mul(add_7, rsqrt_2) -> mul_12
Mul(norm.weight, mul_12) -> output_0
output: name='output_0' type=dtype('float32') shape=[2, 1024, 16]
output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 8]
output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 8]
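The exported graph can be checked against the eager model with onnxruntime. The snippet below is only a minimal sketch, not part of the generated example; it assumes the variables defined above (input_ids, input_mask, model, onx) are still in scope and that to_onnx returned a ModelProto.

import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
feeds = {
    "input_ids": input_ids.numpy(),
    "attention_mask": input_mask.numpy(),
}
# output_0 is the last hidden state, output_1 and output_2 the cached key and value.
got = sess.run(None, feeds)
expected = model(input_ids, input_mask)
np.testing.assert_allclose(
    expected.last_hidden_state.detach().numpy(), got[0], atol=1e-4
)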
Full Example¶
import torch
from transformers import AutoConfig, AutoModelForCausalLM

location = "meta-llama/Llama-2-7b-hf"
cache_dir = "_cache"
# access to the Llama-2 weights is gated, authentication is required
use_auth_token = True

l_config = AutoConfig.from_pretrained(
    location, use_auth_token=use_auth_token, cache_dir=cache_dir
)
l_config.use_cache = True
llama = AutoModelForCausalLM.from_pretrained(
    location,
    use_auth_token=use_auth_token,
    config=l_config,
    torch_dtype=torch.float32,
    cache_dir=cache_dir,
)
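The loaded model can then be converted following the same pattern as the Llama 3 example below. This is only a sketch under those assumptions: the prompt and the output file name are placeholders, and large_model=True stores the weights in external files as in the next section.

from transformers import AutoTokenizer
from experimental_experiment.torch_interpreter import to_onnx

tokenizer = AutoTokenizer.from_pretrained(
    location, use_auth_token=use_auth_token, cache_dir=cache_dir
)
input_ids = tokenizer(
    "Is the conversion to onnx going to work?", return_tensors="pt"
).input_ids
with torch.no_grad():
    large_onx = to_onnx(llama.eval(), (input_ids,), large_model=True, verbose=1)
large_onx.save("llama2.onnx")  # hypothetical output file name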
Llama 3¶
See Llama3.
import os
import time
import numpy as np
import onnxruntime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from experimental_experiment.torch_interpreter import to_onnx

model_id = "meta-llama/Meta-Llama-3-8B"

with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    base_prompt = "Is the conversion to onnx going to work?"
    base_inputs = tokenizer(base_prompt, return_tensors="pt")  # .to("cpu")
    input_ids = base_inputs.input_ids
    expected = model(input_ids)

    print(f"output type: {type(expected)}")
    print(f"logits: {expected.logits.shape}, {expected.logits.dtype}")

    print(
        "start conversion... with input_ids", input_ids.dtype, input_ids.shape
    )
    begin = time.perf_counter()
    large_onx = to_onnx(
        model,
        (input_ids,),
        input_names=["x"],
        verbose=1,
        large_model=True,
        # dynamic_shapes fails with transformers==4.37.2
        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
        # dynamic_shapes={"x": {1: torch.export.Dim("length", min=2)}},
    )
    duration = time.perf_counter() - begin
    print(f"conversion done in {duration}s")

    folder = "test_zoo_export_llama3"
    if not os.path.exists(folder):
        os.mkdir(folder)
    else:
        # remove files left over from a previous run
        for _ in os.listdir(folder):
            os.remove(os.path.join(folder, _))

    print(f"start saving in {folder!r}")
    begin = time.perf_counter()
    filename = os.path.join(folder, "llama3.onnx")
    large_onx.save(filename)
    duration = time.perf_counter() - begin
    print(f"saving done in {duration}s with {len(os.listdir(folder))} files")

    print(f"loading model {filename!r} with onnxruntime.")
    begin = time.perf_counter()
    sess = onnxruntime.InferenceSession(
        filename, providers=["CPUExecutionProvider"]
    )
    print(f"done in {time.perf_counter() - begin}s")

    print("running the first iteration")
    begin = time.perf_counter()
    name = large_onx.model_proto.graph.input[0].name
    np_input = input_ids.detach().cpu().numpy()
    got = sess.run(None, {name: np_input})
    print(f"done in {time.perf_counter() - begin}s")
    # the original snippet comes from a unit test and used self.assertEqualArray;
    # a plain numpy check does the same outside of the test class
    np.testing.assert_allclose(expected.logits.detach().numpy(), got[0], atol=1e-4)

    N = 5
    print(f"running {N} iterations with torch")
    begin = time.perf_counter()
    for i in range(N):
        model(input_ids)
    d = time.perf_counter() - begin
    print(f"done in {d}s for torch")

    print(f"running {N} iterations with onnxruntime")
    begin = time.perf_counter()
    for i in range(N):
        sess.run(None, {name: np_input})
    d = time.perf_counter() - begin
    print(f"done in {d}s for onnxruntime")