LLaMa¶
Dummy Example¶
<<<
import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from experimental_experiment.torch_interpreter import to_onnx
def ids_tensor(shape, vocab_size):
    """Return a random tensor of token ids.

    :param shape: sequence of ints, the shape of the tensor to create
    :param vocab_size: size of the vocabulary, ids are drawn uniformly
        from ``[0, vocab_size)``
    :return: a contiguous ``torch.long`` tensor of shape *shape*
    """
    # np.random.randint excludes the upper bound, so passing vocab_size
    # (and not vocab_size - 1) lets every valid token id be drawn and
    # keeps vocab_size == 1 working. A single vectorised call replaces
    # one Python-level draw per element.
    values = np.random.randint(0, vocab_size, size=tuple(shape), dtype=np.int64)
    return torch.from_numpy(values).contiguous()
# A deliberately tiny configuration (one hidden layer, 16 hidden units)
# so that the exported graph stays small enough to print.
config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
# Select the "eager" attention implementation for the model.
config._attn_implementation = "eager"

# Gradients are not needed to run or export the model.
with torch.no_grad():
    model = LlamaModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    # Dummy inputs: random token ids and a lower-triangular float mask.
    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # Run the model once with these inputs before exporting it.
    model(input_ids, input_mask)

    # Convert the model with the same inputs and print the resulting graph
    # as text.
    onx = to_onnx(model, (input_ids, input_mask))
    print(onnx_simple_text_plot(onx))
>>>
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.fast_dtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
self.prev = torch.is_autocast_cpu_enabled()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self._enabled)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self.prev)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.fast_dtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
self.prev = torch.is_autocast_cpu_enabled()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self._enabled)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self.prev)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.fast_dtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
self.prev = torch.is_autocast_cpu_enabled()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self._enabled)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self.prev)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.fast_dtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
self.prev = torch.is_autocast_cpu_enabled()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self._enabled)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self.prev)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.fast_dtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
self.prev = torch.is_autocast_cpu_enabled()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self._enabled)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
torch.set_autocast_cpu_enabled(self.prev)
/home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[2, 1024]
input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
init: name='p_layers_0_input_layernorm_weight' type=dtype('float32') shape=(16,)
init: name='p_layers_0_post_attention_layernorm_weight' type=dtype('float32') shape=(16,)
init: name='p_norm_weight' type=dtype('float32') shape=(16,)
init: name='p_embed_tokens_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_layers_0_self_attn_q_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='p_layers_0_self_attn_k_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='p_layers_0_self_attn_v_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='p_layers_0_self_attn_o_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='p_layers_0_mlp_gate_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='p_layers_0_mlp_up_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='p_layers_0_mlp_down_proj_weight' type=dtype('float32') shape=(16, 16)
init: name='b_layers_0_self_attn_rotary_emb_inv_freq' type=dtype('float32') shape=(4,) -- array([1. , 0.1 , 0.01 , 0.001], dtype=float32)
init: name='init7_s_0' type=dtype('int64') shape=() -- array([0])
init: name='init7_s_1024' type=dtype('int64') shape=() -- array([1024])
init: name='init7_s_1' type=dtype('int64') shape=() -- array([1])
init: name='init7_s2_1024_1024' type=dtype('int64') shape=(2,) -- array([1024, 1024])
init: name='init7_s2_-1_1' type=dtype('int64') shape=(2,) -- array([-1, 1])
init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
init: name='init7_s4_2_1_1024_1024' type=dtype('int64') shape=(4,) -- array([ 2, 1, 1024, 1024])
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='init1_s_2' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
init: name='init7_s1_-1' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init1_s_3' type=dtype('float32') shape=() -- array([1.e-06], dtype=float32)
init: name='init7_s2_2048_16' type=dtype('int64') shape=(2,) -- array([2048, 16])
init: name='init7_s3_2_1024_16' type=dtype('int64') shape=(3,) -- array([ 2, 1024, 16])
init: name='init7_s4_2_1024_2_8' type=dtype('int64') shape=(4,) -- array([ 2, 1024, 2, 8])
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
init: name='init1_s_5' type=dtype('float32') shape=() -- array([1.e-06], dtype=float32)
init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
init: name='init1_s_6' type=dtype('float32') shape=() -- array([1.e-06], dtype=float32)
init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_1_2' type=dtype('int64') shape=(2,) -- array([1, 2])
init: name='init7_s2_0_2' type=dtype('int64') shape=(2,) -- array([0, 2])
init: name='init1_s_7' type=dtype('float32') shape=() -- array([0.354], dtype=float32)
init: name='init7_s2_4_4' type=dtype('int64') shape=(2,) -- array([4, 4])
ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
Trilu(full, init7_s_1, upper=1) -> _onx_trilu0
Gather(p_embed_tokens_weight, input_ids) -> embedding
Pow(embedding, init1_s1_2) -> pow_1
ReduceMean(pow_1, init7_s1_-1, keepdims=1) -> mean
Add(mean, init1_s_3) -> add
Sqrt(add) -> _onx_sqrt0
Reciprocal(_onx_sqrt0) -> rsqrt
Mul(embedding, rsqrt) -> mul_2
Mul(p_layers_0_input_layernorm_weight, mul_2) -> mul_3
Reshape(mul_3, init7_s2_2048_16) -> view_1
Gemm(view_1, p_layers_0_self_attn_v_proj_weight, transA=0, transB=1) -> mm_2
Reshape(mm_2, init7_s4_2_1024_2_8) -> view_9
Transpose(view_9, perm=[0,2,1,3]) -> output_2
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
Unsqueeze(arange, init7_s2_0_1) -> unsqueeze_9
Cast(unsqueeze_9, to=1) -> _to_copy_2
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
Reshape(arange, init7_s2_-1_1) -> view
Greater(arange_1, view) -> gt
Cast(gt, to=1) -> _onx_cast0
Mul(_onx_trilu0, _onx_cast0) -> _onx_mul0
Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
Equal(expand_1, init1_s_) -> eq
Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
Equal(unsqueeze_6, init1_s_2) -> eq_1
And(eq, eq_1) -> mul_1
Where(mul_1, init1_s1_, expand_1) -> _onx_where0
Gemm(view_1, p_layers_0_self_attn_k_proj_weight, transA=0, transB=1) -> mm_1
Reshape(mm_1, init7_s4_2_1024_2_8) -> view_8
Transpose(view_8, perm=[0,2,1,3]) -> transpose_1
Split(transpose_1, init7_s2_4_4, axis=3) -> slice_12, slice_13
Neg(slice_13) -> neg_1
Concat(neg_1, slice_12, axis=-1) -> cat_2
Gemm(view_1, p_layers_0_self_attn_q_proj_weight, transA=0, transB=1) -> mm
Reshape(mm, init7_s4_2_1024_2_8) -> view_7
Transpose(view_7, perm=[0,2,1,3]) -> transpose
Split(transpose, init7_s2_4_4, axis=3) -> slice_10, slice_11
Neg(slice_11) -> neg
Concat(neg, slice_10, axis=-1) -> cat_1
Unsqueeze(b_layers_0_self_attn_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
MatMul(unsqueeze_8, _to_copy_2) -> view_12
Transpose(view_12, perm=[0,2,1]) -> transpose_3
Concat(transpose_3, transpose_3, axis=-1) -> cat
Cos(cat) -> cos
Unsqueeze(cos, init7_s1_1) -> unsqueeze_10
Mul(transpose, unsqueeze_10) -> mul_4
Sin(cat) -> sin
Unsqueeze(sin, init7_s1_1) -> unsqueeze_11
Mul(cat_1, unsqueeze_11) -> mul_5
Add(mul_4, mul_5) -> add_1
Mul(transpose_1, unsqueeze_10) -> mul_6
Mul(cat_2, unsqueeze_11) -> mul_7
Add(mul_6, mul_7) -> output_1
Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
MatMul(add_1, transpose_4) -> view_13
Mul(view_13, init1_s_7) -> div
Add(div, _onx_where0) -> add_3
Softmax(add_3, axis=-1) -> _softmax
MatMul(_softmax, output_2) -> view_15
Transpose(view_15, perm=[0,2,1,3]) -> transpose_5
Reshape(transpose_5, init7_s2_2048_16) -> view_17
Gemm(view_17, p_layers_0_self_attn_o_proj_weight, transA=0, transB=1) -> mm_3
Reshape(mm_3, init7_s3_2_1024_16) -> view_18
Add(embedding, view_18) -> add_4
Pow(add_4, init1_s1_3) -> pow_2
ReduceMean(pow_2, init7_s1_-1, keepdims=1) -> mean_1
Add(mean_1, init1_s_5) -> add_5
Sqrt(add_5) -> _onx_sqrt02
Reciprocal(_onx_sqrt02) -> rsqrt_1
Mul(add_4, rsqrt_1) -> mul_8
Mul(p_layers_0_post_attention_layernorm_weight, mul_8) -> mul_9
Reshape(mul_9, init7_s2_2048_16) -> view_19
Gemm(view_19, p_layers_0_mlp_up_proj_weight, transA=0, transB=1) -> mm_5
Gemm(view_19, p_layers_0_mlp_gate_proj_weight, transA=0, transB=1) -> mm_4
Reshape(mm_4, init7_s3_2_1024_16) -> view_20
Sigmoid(view_20) -> _onx_sigmoid0
Reshape(_onx_sigmoid0, init7_s2_2048_16) -> Reshape2Of3PatternR__onx_sigmoid0
Mul(mm_4, Reshape2Of3PatternR__onx_sigmoid0) -> Reshape2Of3PatternL_silu
Mul(Reshape2Of3PatternL_silu, mm_5) -> view_23
Gemm(view_23, p_layers_0_mlp_down_proj_weight, transA=0, transB=1) -> mm_6
Reshape(mm_6, init7_s3_2_1024_16) -> view_24
Add(add_4, view_24) -> add_6
Pow(add_6, init1_s1_4) -> pow_3
ReduceMean(pow_3, init7_s1_-1, keepdims=1) -> mean_2
Add(mean_2, init1_s_6) -> add_7
Sqrt(add_7) -> _onx_sqrt03
Reciprocal(_onx_sqrt03) -> rsqrt_2
Mul(add_6, rsqrt_2) -> mul_11
Mul(p_norm_weight, mul_11) -> output_0
output: name='output_0' type=dtype('float32') shape=[2, 1024, 16]
output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 8]
output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 8]
Full Example¶
import torch
from transformers import AutoConfig, AutoModelForCausalLM
location = "meta-llama/Llama-2-7b-hf"
cache_dir = "_cache"
# Placeholder so the snippet is self-contained: replace with a Hugging Face
# access token (or True to reuse the locally stored login) if the
# repository requires authentication.
use_auth_token = None

# Load the configuration first so it can be adjusted before the model
# itself is created.
l_config = AutoConfig.from_pretrained(
    location, use_auth_token=use_auth_token, cache_dir=cache_dir
)
l_config.use_cache = True

# Instantiate the model in float32 from the adjusted configuration.
llama = AutoModelForCausalLM.from_pretrained(
    location,
    use_auth_token=use_auth_token,
    config=l_config,
    torch_dtype=torch.float32,
    cache_dir=cache_dir,
)
Llama 3¶
See Llama3.
import os
import time
import onnxruntime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from experimental_experiment.torch_interpreter import to_onnx
model_id = "meta-llama/Meta-Llama-3-8B"

# The whole scenario (reference run, export, onnxruntime run, timings) is
# executed without autograd tracking.
with torch.no_grad():
    # Reference run with torch: it produces the expected logits.
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    base_prompt = "Is the conversion to onnx going to work?"
    base_inputs = tokenizer(base_prompt, return_tensors="pt")  # .to("cpu")
    input_ids = base_inputs.input_ids
    expected = model(input_ids)

    print(f"output type: {type(expected)}")
    print(f"logits: {expected.logits.shape}, {expected.logits.dtype}")

    # Conversion to ONNX.
    print(
        "start conversion... with input_ids", input_ids.dtype, input_ids.shape
    )
    begin = time.perf_counter()
    large_onx = to_onnx(
        model,
        (input_ids,),
        input_names=["x"],
        verbose=1,
        large_model=True,
        # dynamic_shapes fails with transformers==4.37.2
        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
        # dynamic_shapes={"x": {1: torch.export.Dim("length", min=2)}},
    )
    duration = time.perf_counter() - begin
    print(f"conversion done in {duration}s")

    # Start from an empty output folder: create it, or remove the files
    # left there by a previous run.
    folder = "test_zoo_export_llama3"
    if not os.path.exists(folder):
        os.mkdir(folder)
    else:
        for existing in os.listdir(folder):
            os.remove(os.path.join(folder, existing))

    # Save the converted model.
    print(f"start saving in {folder!r}")
    begin = time.perf_counter()
    filename = os.path.join(folder, "llama3.onnx")
    large_onx.save(filename)
    duration = time.perf_counter() - begin
    print(f"saving done in {duration}s with {len(os.listdir(folder))} files")

    # Load the saved model with onnxruntime on CPU.
    print(f"loading model {filename!r} with onnxruntime.")
    begin = time.perf_counter()
    sess = onnxruntime.InferenceSession(
        filename, providers=["CPUExecutionProvider"]
    )
    print(f"done in {time.perf_counter() - begin}s")

    # First onnxruntime run, checked against the torch logits.
    print("running the first iteration")
    begin = time.perf_counter()
    name = large_onx.model_proto.graph.input[0].name
    np_input = input_ids.detach().cpu().numpy()
    got = sess.run(None, {name: np_input})
    print(f"done in {time.perf_counter() - begin}s")
    # This is a standalone script, there is no unittest ``self``:
    # compare with torch's own helper (absolute tolerance only).
    torch.testing.assert_close(
        torch.from_numpy(got[0]), expected.logits, atol=1e-4, rtol=0
    )

    # Rough timing comparison over N runs of each backend.
    N = 5
    print(f"running {N} iterations with torch")
    begin = time.perf_counter()
    for _ in range(N):
        model(input_ids)
    d = time.perf_counter() - begin
    print(f"done in {d}s for torch")

    print(f"running {N} iterations with onnxruntime")
    begin = time.perf_counter()
    for _ in range(N):
        sess.run(None, {name: np_input})
    d = time.perf_counter() - begin
    print(f"done in {d}s for onnxruntime")