Export a LLAMA model into ONNX

This script does not export a full Llama model but a smaller one, so that it is possible to iterate quickly on improvements. See LlamaConfig. The model is then converted into ONNX. It can be visualized with Netron, which can also be used through a VS Code extension.

The model

import os
import random


def ids_tensor(shape, vocab_size, rng=None, name=None):
    """
    Creates a random ``torch.long`` (int64) tensor of the requested shape
    whose values lie in ``[0, vocab_size - 1]``.

    :param shape: sequence of dimensions of the tensor to create
    :param vocab_size: exclusive upper bound of the generated ids
    :param rng: ``random.Random`` instance used to draw the values,
        a new unseeded one is created when it is None
    :param name: unused, kept so that existing callers passing it still work
    :return: contiguous tensor of dtype ``torch.long`` and shape *shape*
    """
    import math

    import torch

    if rng is None:
        rng = random.Random()

    # One draw per element, in flat (row-major) order, then reshaped.
    n_values = math.prod(shape)
    values = [rng.randint(0, vocab_size - 1) for _ in range(n_values)]

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


def get_llama_model(
    input_dims=[(2, 1024)],  # noqa: B006
    hidden_size=1024,  # 4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=4,  # 32,
    _attn_implementation="eager",
    with_mask: bool = True,
):
    """
    Builds a small ``LlamaModel`` wrapped in a ``torch.nn.Module`` together
    with example inputs.

    :param input_dims: pairs ``(batch, sequence_length)``, one set of
        example inputs is generated for each pair
    :param hidden_size: forwarded to ``LlamaConfig``
    :param num_hidden_layers: forwarded to ``LlamaConfig``
    :param vocab_size: forwarded to ``LlamaConfig``, also the bound used
        to draw the example input ids
    :param intermediate_size: forwarded to ``LlamaConfig``
    :param max_position_embeddings: forwarded to ``LlamaConfig``
    :param num_attention_heads: forwarded to ``LlamaConfig``
    :param _attn_implementation: when truthy, assigned to
        ``config._attn_implementation``
    :param with_mask: NOTE(review): never read in this function, the
        wrapper always takes an attention mask
    :return: tuple ``(model, examples)`` where *examples* is a list of
        tuples ``(input_ids, attention_mask)``
    """
    import torch
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    llama_config = LlamaConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
    )
    if _attn_implementation:
        llama_config._attn_implementation = _attn_implementation

    class LlamaModelWrapper(torch.nn.Module):
        # Exposes a purely positional signature and returns a plain tuple
        # instead of the model output object.
        def __init__(self, config):
            super().__init__()
            self.model = LlamaModel(config)

        def forward(self, input_ids, attention_mask):
            return self.model(input_ids, attention_mask=attention_mask).to_tuple()

    def generate_example_inputs(batch: int, seq: int, vocab_size: int):
        # Random ids plus a float32 lower-triangular mask of shape (batch, seq).
        ids = ids_tensor([batch, seq], vocab_size)
        mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
        assert mask.dtype == torch.float32
        return ids, mask

    examples = [
        generate_example_inputs(batch, seq, vocab_size) for batch, seq in input_dims
    ]

    return LlamaModelWrapper(llama_config), examples


# Builds the reduced model with the default settings and keeps the example
# inputs generated alongside it; the first set is later used as export arguments.
print("creation of the model.")
model, example_args_collection = get_llama_model()
print("done.")
creation of the model.
~/vv/this312/lib/python3.12/site-packages/torch/compiler/__init__.py:148: FutureWarning: torch._dynamo.allow_in_graph is deprecated and will be removed in a future version. Use torch._dynamo.nonstrict_trace instead.
  return torch._dynamo.allow_in_graph(fn)
done.

The conversion to ONNX

def export(model, args, filename, dynamic_shapes):
    """
    Converts *model* into ONNX with ``to_onnx`` (opset 18, non-strict
    export) while the export patches for transformers are active.

    :param model: module to convert
    :param args: positional example inputs given to the exporter
    :param filename: passed to ``to_onnx`` as the output file name
    :param dynamic_shapes: dynamic shapes specification, forwarded as is
    """
    from experimental_experiment.torch_interpreter import ExportOptions, to_onnx
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    with bypass_export_some_errors(patch_transformers=True):
        options = ExportOptions(strict=False)
        to_onnx(
            model,
            args,
            filename=filename,
            target_opset=18,
            dynamic_shapes=dynamic_shapes,
            export_options=options,
        )


filename = "dump_llama.onnx"
print(f"conversion to ONNX in file {filename!r}")
# The first set of example inputs (input_ids, attention mask) is used for the
# export; for both inputs, axis 0 is named "batch" and axis 1 "seq_length".
export(
    model,
    example_args_collection[0],
    filename,
    dynamic_shapes=({0: "batch", 1: "seq_length"}, {0: "batch", 1: "seq_length"}),
)
print("done.")
# st_size is a number of bytes; dividing by 2**20 converts it to mebibytes.
print(f"model size {os.stat(filename).st_size / 2**20} Mb.")
Traceback (most recent call last):
  File "~/github/teachcompute/_doc/examples/plot_export_model_onnx.py", line 115, in <module>
    export(
  File "~/github/teachcompute/_doc/examples/plot_export_model_onnx.py", line 103, in export
    to_onnx(
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 1035, in to_onnx
    graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 593, in _make_builder_interpreter
    exported_program = export_options.export(
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py", line 708, in export
    exported_program = self._export(
                       ^^^^^^^^^^^^^
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py", line 343, in _export
    return torch_export(
           ^^^^^^^^^^^^^
  File "~/github/experimental-experiment/experimental_experiment/export_helpers.py", line 146, in torch_export
    return torch.export.export(
           ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 171, in export
    return _export(
           ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2300, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(
                           ^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2006, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2136, in _aot_export_non_strict
    gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1914, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2856, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2757, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2718, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1271, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1554, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2290, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 890, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1624, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1798, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 192, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1520, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2379, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2120, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2379, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/teachcompute/_doc/examples/plot_export_model_onnx.py", line 73, in forward
    model_output = self.model(input_ids, attention_mask=attention_mask)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2379, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/transformers/src/transformers/utils/generic.py", line 952, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/transformers/src/transformers/utils/output_capturing.py", line 248, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/transformers/src/transformers/models/llama/modeling_llama.py", line 399, in forward
    causal_mask = create_causal_mask(
                  ^^^^^^^^^^^^^^^^^^^
  File "~/github/transformers/src/transformers/utils/deprecation.py", line 171, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/github/transformers/src/transformers/masking_utils.py", line 983, in create_causal_mask
    causal_mask = mask_interface(
                  ^^^^^^^^^^^^^^^
TypeError: patched_eager_mask() missing 1 required positional argument: 'cache_position'

This gives the following in Netron:

../_images/llama1.png

Total running time of the script: (0 minutes 12.821 seconds)

Gallery generated by Sphinx-Gallery