Note
Go to the end to download the full example code
Export a LLAMA model into ONNX¶
This script does not export a full Llama model but a shorter one, so that iterating on improvements stays fast. See LlamaConfig. The model is then converted into ONNX. It can be viewed with Netron, which is also available as a VS Code extension.
The model¶
import contextlib
import io
import math
import os
import random
import warnings
def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Create a random integer tensor of the given shape.

    The original comment claimed int32; the tensor is actually created with
    ``dtype=torch.long`` (int64), which is what embedding lookups expect.

    :param shape: sequence of dimension sizes for the output tensor
    :param vocab_size: values are drawn uniformly from ``[0, vocab_size - 1]``
    :param rng: optional :class:`random.Random` for reproducibility;
        a fresh unseeded instance is created when omitted
    :param name: unused, kept for interface compatibility
    :return: an int64 tensor of the requested shape
    """
    import torch

    if rng is None:
        rng = random.Random()
    # math.prod replaces the manual accumulation loop.
    total_dims = math.prod(shape)
    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
def get_llama_model(
    input_dims=None,
    hidden_size=1024,  # 4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=4,  # 32,
    _attn_implementation="eager",
    with_mask: bool = True,
):
    """Instantiate a reduced Llama model plus example inputs.

    The defaults shrink the architecture (hidden size, attention heads,
    number of layers) relative to the full model so that building and
    exporting stay fast.

    :param input_dims: list of ``(batch, sequence)`` pairs; one example input
        set is generated per pair.  Defaults to ``[(2, 1024)]``.
    :param hidden_size: embedding dimension of the model
    :param num_hidden_layers: number of transformer layers
    :param vocab_size: size of the token vocabulary
    :param intermediate_size: width of the feed-forward layers
    :param max_position_embeddings: maximum supported sequence length
    :param num_attention_heads: number of attention heads
    :param _attn_implementation: forced attention backend (e.g. ``"eager"``);
        skipped when falsy
    :param with_mask: not used in the visible body — a mask is always
        generated; kept for interface compatibility (TODO confirm intent)
    :return: tuple ``(wrapper_module, example_args_collection)`` where each
        collection element is an ``(input_ids, attention_mask)`` pair
    """
    import torch
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    if input_dims is None:
        # Build a fresh list per call: a mutable default argument would be
        # shared across all calls of this function.
        input_dims = [(2, 1024)]

    config = LlamaConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
    )
    if _attn_implementation:
        config._attn_implementation = _attn_implementation

    class LlamaModelWrapper(torch.nn.Module):
        # Thin wrapper that returns the model output as a plain tuple,
        # which the ONNX exporter handles better than a ModelOutput object.
        def __init__(self, config):
            super().__init__()
            self.model = LlamaModel(config)

        def forward(self, input_ids, attention_mask):
            model_output = self.model(input_ids, attention_mask=attention_mask)
            return model_output.to_tuple()

    def generate_example_inputs(batch: int, seq: int, vocab_size: int):
        # Lower-triangular float mask; presumably mirrors causal attention.
        input_ids = ids_tensor([batch, seq], vocab_size)
        input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
        assert input_mask.dtype == torch.float32
        return input_ids, input_mask

    example_args_collection = []
    for b, s in input_dims:
        example_args_collection.append(generate_example_inputs(b, s, vocab_size))
    return LlamaModelWrapper(config), example_args_collection
print("creation of the model.")
# Build the reduced Llama wrapper and one set of (input_ids, mask) example
# input pairs using the default configuration.
model, example_args_collection = get_llama_model()
print("done.")
creation of the model.
done.
The conversion to ONNX¶
def export(model, args, filename):
import torch
with contextlib.redirect_stdout(io.StringIO()):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
torch.onnx.export(
model, args, filename, input_names=["input", "mask"], opset_version=17
)
filename = "dump_llama.onnx"
# BUG FIX: the f prefix was missing, so the literal "{filename!r}" was
# printed instead of the actual file name.
print(f"conversion to ONNX in file {filename!r}")
# Export using the first example input pair generated above.
export(model, example_args_collection[0], filename)
print("done.")
print(f"model size {os.stat(filename).st_size / 2**20} Mb.")
conversion to ONNX in file {filename!r}
done.
model size 276.0319871902466 Mb.
This gives the following in Netron:
Total running time of the script: (0 minutes 23.126 seconds)