Export a LLAMA model into ONNX

This script does not export a full llama model but a shorter one, so that it is possible to iterate quickly on improvements. See LlamaConfig. The model is then converted into ONNX. It can be viewed with Netron, which can also be used through a VS Code extension.

The model

import contextlib
import io
import os
import random
import warnings


def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Create a random integer tensor of the given shape.

    Every element is drawn independently with ``rng.randint`` from the
    closed interval ``[0, vocab_size - 1]``.

    :param shape: sequence of dimensions of the tensor to create
    :param vocab_size: exclusive upper bound of the generated values
    :param rng: object exposing ``randint(a, b)``; a fresh, unseeded
        ``random.Random()`` is used when it is None
    :param name: not used by this function, kept so existing callers
        that pass it keep working
    :return: contiguous ``torch.long`` (int64) tensor of shape *shape*
    """
    import math

    import torch

    if rng is None:
        rng = random.Random()

    # Number of scalars to draw; math.prod returns 1 for an empty shape,
    # exactly like the explicit multiplication loop it replaces.
    total_dims = math.prod(shape)

    # One rng.randint call per element, in flat (row-major) order.
    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


def get_llama_model(
    input_dims=[(2, 1024)],  # noqa: B006
    hidden_size=1024,  # 4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=4,  # 32,
    _attn_implementation="eager",
    with_mask: bool = True,
):
    """Build a small Llama model together with matching example inputs.

    :param input_dims: iterable of ``(batch, sequence_length)`` pairs, one
        pair of example inputs is generated for each of them (the default
        list is only iterated, never modified)
    :param hidden_size: forwarded to ``LlamaConfig``
    :param num_hidden_layers: forwarded to ``LlamaConfig``
    :param vocab_size: forwarded to ``LlamaConfig`` and used as the
        exclusive upper bound of the random token ids
    :param intermediate_size: forwarded to ``LlamaConfig``
    :param max_position_embeddings: forwarded to ``LlamaConfig``
    :param num_attention_heads: forwarded to ``LlamaConfig``
    :param _attn_implementation: when truthy, assigned to
        ``config._attn_implementation``
    :param with_mask: accepted but never read by this function; the mask
        is always produced and always expected by the wrapper
        (NOTE(review): confirm whether a mask-less variant was intended)
    :return: tuple ``(model, example_args_collection)`` where *model* is a
        ``torch.nn.Module`` taking ``(input_ids, attention_mask)`` and
        *example_args_collection* is a list of ``(input_ids, mask)`` tuples
    """
    import torch
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    llama_config = LlamaConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
    )
    if _attn_implementation:
        llama_config._attn_implementation = _attn_implementation

    class LlamaModelWrapper(torch.nn.Module):
        # Thin wrapper returning a plain tuple instead of the model output
        # object produced by LlamaModel.
        def __init__(self, config):
            super().__init__()
            self.model = LlamaModel(config)

        def forward(self, input_ids, attention_mask):
            outputs = self.model(input_ids, attention_mask=attention_mask)
            return outputs.to_tuple()

    def make_inputs(batch: int, seq: int):
        # Random token ids of shape (batch, seq).
        ids = ids_tensor([batch, seq], vocab_size)
        # Lower-triangular float mask of shape (batch, seq): row i holds
        # ones in columns 0..i and zeros elsewhere.
        mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
        assert mask.dtype == torch.float32
        return ids, mask

    # The example inputs are created before the model is instantiated.
    examples = [make_inputs(b, s) for b, s in input_dims]

    return LlamaModelWrapper(llama_config), examples


# Build the reduced-size model and its example inputs with the default
# arguments of get_llama_model; both names are reused by the export step.
print("creation of the model.")
model, example_args_collection = get_llama_model()
print("done.")
creation of the model.
done.

The conversion to ONNX

def export(model, args, filename):
    """Convert *model* into ONNX and write the result to *filename*.

    The call to ``torch.onnx.export`` runs with standard output redirected
    to an in-memory buffer and with all warnings ignored, so the
    conversion is silent.

    :param model: model handed to ``torch.onnx.export``
    :param args: example inputs handed to ``torch.onnx.export``
    :param filename: destination handed to ``torch.onnx.export``
    """
    import torch

    # Whatever the exporter prints ends up in this throwaway buffer.
    sink = io.StringIO()
    with warnings.catch_warnings(), contextlib.redirect_stdout(sink):
        warnings.simplefilter("ignore")
        torch.onnx.export(
            model,
            args,
            filename,
            input_names=["input", "mask"],
            opset_version=17,
        )


# Export the model with the first set of example inputs and report the
# size of the resulting file.
filename = "dump_llama.onnx"
# The f prefix is required here: without it the placeholder
# "{filename!r}" is printed verbatim instead of the file name.
print(f"conversion to ONNX in file {filename!r}")
export(model, example_args_collection[0], filename)
print("done.")
# st_size is in bytes; dividing by 2**20 expresses it in units of 2**20 bytes.
print(f"model size {os.stat(filename).st_size / 2**20} Mb.")
conversion to ONNX in file {filename!r}
done.
model size 278.05313777923584 Mb.

This gives the following in Netron:

../_images/llama1.png

Total running time of the script: (0 minutes 7.830 seconds)

Gallery generated by Sphinx-Gallery