301: Compares LLAMA exporters

The script compares the two exporters implemented in PyTorch on a part of the Llama model. The models are compared after all optimizations have been applied, including those made by onnxruntime.

To run the script:

python _doc/examples/plot_llama_diff_export.py --help

Some helpers

from experimental_experiment.args import get_parsed_args

script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    exporter=("dynamo", "one value among dynamo, custom"),
    ortopt=(1, "run onnxruntime optimization"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,opset",
)

import contextlib
import os
import io
import warnings
import logging

# onnxruntime is required by this example: print a message and exit with
# status 0 when it cannot be imported.
try:
    with warnings.catch_warnings():
        # Silence any warning raised while importing onnxruntime.
        warnings.simplefilter("ignore")
        import onnxruntime

        available_providers = onnxruntime.get_available_providers()
        has_cuda = "CUDAExecutionProvider" in available_providers
except ImportError:
    import sys

    print("onnxruntime not available.")
    sys.exit(0)

import numpy as np
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.convert.convert_helper import (
    optimize_model_proto_oxs,
    ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
    get_llama_model,
    get_llama_attention,
    get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import reorder_functions_in_proto

# CUDA is used only if onnxruntime offers CUDAExecutionProvider AND torch
# reports CUDA as available.
if has_cuda:
    has_cuda = torch.cuda.is_available()
# Disable logging calls of severity ERROR and below.
logging.disable(logging.ERROR)
if has_cuda:
    provider = "cuda"
else:
    provider = "cpu"

The exporting functions

# Echo the string options as given.
for option in ("part", "exporter"):
    print(f"{option}={getattr(script_args, option)}")
# Accept both the integer default 1 and the string "1" as "enabled".
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
opset = int(script_args.opset)
print(f"opset={opset}")


def opt_filename(filename: str) -> str:
    """Return *filename* with ``.opt`` inserted before its extension.

    ``model.onnx`` becomes ``model.opt.onnx``; a name without extension
    simply gets ``.opt`` appended.
    """
    stem, extension = os.path.splitext(filename)
    return stem + ".opt" + extension


def export_script(filename, model, *args):
    """Export *model* with the TorchScript-based ONNX exporter.

    The ONNX model is written to *filename*. Anything printed on stdout
    during the export is captured and discarded, and warnings are ignored.
    When ``ortopt`` is set, an onnxruntime-optimized copy is written to
    ``opt_filename(filename)``.

    :param filename: path of the ONNX file to write
    :param model: torch module to export
    :param args: example positional inputs given to the model
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # ``dynamo=False`` pins the TorchScript exporter explicitly:
            # recent torch releases switched the default of
            # ``torch.onnx.export`` to the dynamo-based exporter, which would
            # make this comparison run the same exporter twice.
            torch.onnx.export(
                model,
                args,
                filename,
                input_names=["input"],
                opset_version=opset,
                dynamo=False,
            )
    if ortopt:
        onx = onnx.load(filename)
        ort_optimize(onx, opt_filename(filename), providers=provider)


def export_dynamo(filename, model, *args):
    """Export *model* with the dynamo-based ONNX exporter.

    The exported proto goes through ``optimize_model_proto_oxs``. If that
    raises ImportError, a message is printed and the unoptimized proto is
    kept. The result is saved to *filename*. When ``ortopt`` is set, an
    onnxruntime-optimized copy is written to ``opt_filename(filename)``.

    :param filename: path of the ONNX file to write
    :param model: torch module to export
    :param args: example positional inputs given to the model
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Forward the requested opset so ``--opset`` applies to this
            # exporter exactly as it does to the two other ones.
            export_output = torch.onnx.export(
                model, args, dynamo=True, opset_version=opset
            )
            # Use a separate name: ``model`` is the torch module, this is
            # the exported ONNX proto.
            proto = export_output.model_proto
    try:
        new_model = optimize_model_proto_oxs(proto)
    except ImportError as e:
        print("skipping optimization, missing package or failure:", e)
        new_model = proto
    with open(filename, "wb") as f:
        f.write(new_model.SerializeToString())
    if ortopt:
        ort_optimize(new_model, opt_filename(filename), providers=provider)


def export_custom(filename, model, *args):
    """Export *model* with ``to_onnx`` from ``experimental_experiment``.

    Inputs are named ``input0``, ``input1``, ...; the exporter runs with
    ``remove_unused=True`` and ``constant_folding=False`` and targets the
    opset selected on the command line. The model is saved to *filename*.
    When ``ortopt`` is set, an onnxruntime-optimized copy is written to
    ``opt_filename(filename)``.

    :param filename: path of the ONNX file to write
    :param model: torch module to export
    :param args: example positional inputs given to the model
    """
    names = [f"input{index}" for index, _ in enumerate(args)]
    options = OptimizationOptions(remove_unused=True, constant_folding=False)
    proto = to_onnx(
        model,
        tuple(args),
        input_names=names,
        options=options,
        target_opset=opset,
    )
    with open(filename, "wb") as stream:
        stream.write(proto.SerializeToString())
    if ortopt:
        ort_optimize(proto, opt_filename(filename), providers=provider)
part=attention
exporter=dynamo
ortopt=True
opset=18

Model and data

# Under unit tests only the input dimensions are given; otherwise the
# model configuration is spelled out explicitly.
kwargs = dict(input_dims=[(2, 1024)] * 2)
if not unit_test_going():
    kwargs.update(
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

# Map each supported ``--part`` value to the helper building that piece.
llama_builders = {
    "attention": get_llama_attention,
    "decoder": get_llama_decoder,
    "model": get_llama_model,
}
if script_args.part not in llama_builders:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
model, inputs = llama_builders[script_args.part](**kwargs)

# Run the eager model once on the first set of inputs; ``expected`` is the
# reference every exported model is compared against.
print(f"simple run with {len(inputs)} inputs")
expected = model(*inputs[0])
if not isinstance(expected, tuple):
    print(f"eager mode worked {expected.shape}, {expected.dtype}")
else:
    for item in expected:
        if isinstance(item, tuple):
            print(f"eager worked {type(item)}")
        else:
            print(f"eager worked {item.shape}, {item.dtype}")
simple run with 2 inputs
eager mode worked torch.Size([2, 1024, 512]), torch.float32

Exporting

# The TorchScript export always runs first; the second file is produced by
# the exporter chosen with ``--exporter``.
exporter = script_args.exporter
prefix = f"llama.{script_args.part}"
file1 = f"{prefix}.script.onnx"
file2 = f"{prefix}.{exporter}.onnx"

print("torch script exporter")
export_script(file1, model, *inputs[0])

second_exporters = {
    "dynamo": ("torch dynamo exporter", export_dynamo),
    "custom": ("torch custom exporter", export_custom),
}
if exporter not in second_exporters:
    raise AssertionError(f"Unexpected value for exporter={exporter!r}.")
message, export_fn = second_exporters[exporter]
print(message)
export_fn(file2, model, *inputs[0])
torch script exporter
torch dynamo exporter
Applied 7 of general pattern rewrite rules.

Verification

# With ``ortopt``, compare the onnxruntime-optimized files written by the
# export functions; ``opt_filename`` gives their names.
if ortopt:
    print("Using models optimized by onnxruntime")
    file1 = opt_filename(file1)
    file2 = opt_filename(file2)


providers = (
    ["CPUExecutionProvider"]
    if provider == "cpu"
    else [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
)

model1 = onnx.load(file1)
model2 = onnx.load(file2)

# Feed the same numpy arrays to both models. The two graphs name their
# inputs differently, so names are taken positionally from each graph.
feeds1, feeds2 = {}, {}
for i, tensor in enumerate(inputs[0]):
    x = tensor.detach().numpy()
    feeds1[model1.graph.input[i].name] = x
    feeds2[model2.graph.input[i].name] = x

if ortopt:
    sess1 = onnxruntime.InferenceSession(file1, providers=providers)
    sess2 = onnxruntime.InferenceSession(file2, providers=providers)

    got1 = sess1.run(None, feeds1)
    got2 = sess2.run(None, feeds2)

    # ``expected`` may be a tuple of outputs; only its first element is
    # compared, as in the reference-evaluator check.
    first_expected = expected[0] if isinstance(expected, tuple) else expected
    expected_np = first_expected.detach().numpy()
    diff1 = np.abs(expected_np - got1[0]).max()
    diff2 = np.abs(expected_np - got2[0]).max()

    print(f"Error with the eager model and onnxruntime: {diff1}, {diff2}")
Using models optimized by onnxruntime
Error with the eager model and onnxruntime: 2.9178336262702942e-05, 2.9178336262702942e-05

Verification with the reference evaluator

# Run both files again with onnx_array_api's reference evaluator. If the
# second file raises NotImplementedError when loaded, the error is printed
# and the first model's results are reused for it.
for onnx_file in (file1, file2):
    reorder_functions_in_proto(onnx_file)

sess1 = ExtendedReferenceEvaluator(file1)
try:
    sess2 = ExtendedReferenceEvaluator(file2)
except NotImplementedError as e:
    print(e)
    sess2 = None

got1 = sess1.run(None, feeds1)
if sess2 is None:
    got2 = got1
else:
    got2 = sess2.run(None, feeds2)

# Compare against the first eager output when the model returns a tuple.
reference_output = expected[0] if isinstance(expected, tuple) else expected
reference_array = reference_output.detach().numpy()
diff1, diff2 = (np.abs(reference_array - got[0]).max() for got in (got1, got2))

print(f"Error with the eager model and the reference evaluator: {diff1}, {diff2}")
Error with the eager model and the reference evaluator: 4.0978193283081055e-08, 4.0978193283081055e-08

Comparison and execution

def clean_name(name):
    """Remove two verbose ``_inlfunc_...`` fragments from a result name.

    Used to shorten the names displayed for the second model in the
    side-by-side comparison.
    """
    for fragment in (
        "_inlfunc_transformers_models_llama_modeling_llama_LlamaAttention",
        "_inlfunc_torch_nn_modules_linear_Linear",
    ):
        name = name.replace(fragment, "")
    return name


# Side-by-side execution of both models, only when the reference evaluator
# could load the second one. One known AssertionError from
# ``compare_onnx_execution`` is printed instead of raised.
if sess2 is not None:
    known_error = "Unexpected type <class 'list'> for value, it must be a numpy array."
    try:
        np_inputs = [tensor.detach().numpy() for tensor in inputs[0]]
        res1, res2, align, dc = compare_onnx_execution(
            model1, model2, inputs=np_inputs, verbose=1, raise_exc=False
        )
        for result in res2:
            result.name = clean_name(result.name)
        text = dc.to_str(res1, res2, align, column_size=90)
        print(text)
    except AssertionError as e:
        if known_error not in str(e):
            raise
        print(e)
[compare_onnx_execution] execute with 3 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 60 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 60 results (first model)
[compare_onnx_execution] got 56 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 61 pairs
[compare_onnx_execution] done
001 = | INITIA float32  2:512x512            ZFXE                 onnx::MatMul_171                 | INITIA float32  2:512x512            ZFXE                 t
002 = | INITIA float32  2:512x512            XETY                 onnx::MatMul_172                 | INITIA float32  2:512x512            XETY                 t_1
003 = | INITIA float32  2:512x512            LEXW                 onnx::MatMul_173                 | INITIA float32  2:512x512            LEXW                 t_2
004 - | INITIA float32  2:512x512            BTJW                 onnx::MatMul_219                 |
005 = | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64    1:2                  GGAA                 splits_token_14
006 - | INITIA int64    1:1                  BAAA                 /attention/Constant_25_output_0  |
007 = | INITIA int64    1:4                  CKIM                 /attention/Constant_2_output_0   | INITIA int64    1:4                  CKIM                 val_2
008 + |                                                                                            | INITIA float32  2:512x512            BTJW                 t_3
009 ~ | INITIA int64    1:1                  AAAA                 /attention/Constant_6_output_0   | INITIA int64                         BAAA                 node_aten_unsqueeze_46_dim_0
010 = | INITIA float32  3:1x32x1             DAAA                 /attention/rotary_emb/Expand_out | INITIA float32  3:1x32x1             DAAA                 _to_copy_2
011 - | INITIA int64    1:1                  KAAA                 /attention/Constant_24_output_0  |
012 = | INITIA int64    1:2                  GGAA                 splits_token_14                  | INITIA int64    1:2                  GGAA                 splits
013 - | INITIA int64    1:1                  DAAA                 const_transpose_optimizer_token_ |
014 = | INITIA int64    1:3                  CKZA                 /attention/Constant_26_output_0  | INITIA int64    1:3                  CKZA                 val_115
015 = | INPUT  float32  3:2x1024x512         HVJE                 input                            | INPUT  float32  3:2x1024x512         HVJE                 hidden_states
016 = | INPUT  float32  4:2x1x1024x1024      AAAA                 onnx::Slice_1                    | INPUT  float32  4:2x1x1024x1024      AAAA                 attention_mask
017 = | INPUT  int64    2:1x1024             KAQG                 onnx::Unsqueeze_2                | INPUT  int64    2:1x1024             KAQG                 position_ids
018 = | RESULT int64    3:1x1x1024           KAQG Unsqueeze       /attention/rotary_emb/Unsqueeze_ | RESULT int64    3:1x1x1024           KAQG Unsqueeze       unsqueeze_2
019 = | RESULT float32  3:1x1x1024           KAQG Cast            /attention/rotary_emb/Cast_outpu | RESULT float32  3:1x1x1024           KAQG Cast            _to_copy_1
020 = | RESULT float32  3:1x32x1024          EFXM MatMul          /attention/rotary_emb/MatMul_out | RESULT float32  3:1x32x1024          EFXM MatMul          matmul_3
021 = | RESULT float32  3:1x64x1024          JKJK Concat          /attention/rotary_emb/Concat     | RESULT float32  3:1x64x1024          JKJK Concat          node_Concat_64
022 = | RESULT float32  3:1x64x1024          RMRM Sin             /attention/rotary_emb/Sin        | RESULT float32  3:1x64x1024          RMRM Sin             node_Sin_66
023 = | RESULT float32  4:1x1x64x1024        RMRM Unsqueeze       /attention/Unsqueeze_1           | RESULT float32  4:1x1x64x1024        RMRM Unsqueeze       node_aten_unsqueeze_73_n2
024 = | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_7_out0           | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_7_out0
025 = | RESULT float32  3:2x1024x512         CVTW MatMul          /attention/k_proj/MatMul_output_ | RESULT float32  3:2x1024x512         CVTW MatMul          matmul_1
026 = | RESULT float32  4:2x1024x8x64        CVTW Reshape         /attention/Reshape_1_output_0    | RESULT float32  4:2x1024x8x64        CVTW Reshape         view_1
027 = | RESULT float32  4:2x1024x8x32        YMXO Split           /attention/Slice_2               | RESULT float32  4:2x1024x8x32        YMXO Split           node_Slice_114
028 = | RESULT float32  4:2x1024x8x32        EIXJ Split           /attention/Slice_3               | RESULT float32  4:2x1024x8x32        EIXJ Split           node_Slice_125
029 = | RESULT float32  4:2x1024x8x32        WSDR Neg             /attention/Neg_1                 | RESULT float32  4:2x1024x8x32        WSDR Neg             node_aten_neg_126_n0
030 = | RESULT float32  4:2x1024x8x64        TDAE Concat          /attention/Concat_1              | RESULT float32  4:2x1024x8x64        TDAE Concat          node_Concat_127
031 = | RESULT float32  4:2x1024x8x64        KRKR Mul             /attention/Mul_3                 | RESULT float32  4:2x1024x8x64        KRKR Mul             node_Mul_128
032 = | RESULT float32  3:1x64x1024          NHNH Cos             /attention/rotary_emb/Cos        | RESULT float32  3:1x64x1024          NHNH Cos             node_Cos_65
033 = | RESULT float32  4:1x1x64x1024        NHNH Unsqueeze       /attention/Unsqueeze             | RESULT float32  4:1x1x64x1024        NHNH Unsqueeze       node_aten_unsqueeze_72_n2
034 = | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_11_out0          | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_11_out0
035 = | RESULT float32  4:2x1024x8x64        EPKU Mul             /attention/Mul_2                 | RESULT float32  4:2x1024x8x64        EPKU Mul             node_Mul_103
036 = | RESULT float32  4:2x1024x8x64        OFUL Add             /attention/Add_1                 | RESULT float32  4:2x1024x8x64        OFUL Add             node_Add_129
037 = | RESULT float32  4:2x8x64x1024        LILV Transpose       /attention/Transpose_3_output_0  | RESULT float32  4:2x8x64x1024        LILV Transpose       transpose_4
038 = | RESULT float32  4:1x1x1024x64        GSEC Transpose       /attention/Unsqueeze_1_output_0  | RESULT float32  4:1x1x1024x64        GSEC Transpose       unsqueeze_4
039 = | RESULT float32  3:2x1024x512         QPPG MatMul          /attention/q_proj/MatMul_output_ | RESULT float32  3:2x1024x512         QPPG MatMul          matmul
040 = | RESULT float32  4:2x1024x8x64        QPPG Reshape         /attention/Reshape_output_0      | RESULT float32  4:2x1024x8x64        QPPG Reshape         view
041 = | RESULT float32  4:2x8x1024x64        MTAU Transpose       /attention/Transpose_output_0    | RESULT float32  4:2x8x1024x64        MTAU Transpose       transpose
042 = | RESULT float32  4:2x8x1024x32        YZNT Split           /attention/Slice_output_0        | RESULT float32  4:2x8x1024x32        YZNT Split           slice_4
043 = | RESULT float32  4:2x8x1024x32        PUOB Split           /attention/Slice_1_output_0      | RESULT float32  4:2x8x1024x32        PUOB Split           slice_5
044 = | RESULT float32  4:2x8x1024x32        LGMZ Neg             /attention/Neg_output_0          | RESULT float32  4:2x8x1024x32        LGMZ Neg             neg
045 = | RESULT float32  4:2x8x1024x64        KGYT Concat          /attention/Concat_output_0       | RESULT float32  4:2x8x1024x64        KGYT Concat          cat_1
046 = | RESULT float32  4:2x8x1024x64        LNZK Mul             /attention/Mul_1_output_0        | RESULT float32  4:2x8x1024x64        LNZK Mul             mul_3
047 = | RESULT float32  4:1x1x1024x64        CJYF Transpose       /attention/Unsqueeze_output_0    | RESULT float32  4:1x1x1024x64        CJYF Transpose       unsqueeze_3
048 = | RESULT float32  4:2x8x1024x64        QEUP Mul             /attention/Mul_output_0          | RESULT float32  4:2x8x1024x64        QEUP Mul             mul_2
049 = | RESULT float32  4:2x8x1024x64        CQTA Add             /attention/Add_output_0          | RESULT float32  4:2x8x1024x64        CQTA Add             add
050 = | RESULT float32  4:2x8x1024x1024      CYLH FusedMatMul     /attention/Div_output_0          | RESULT float32  4:2x8x1024x1024      CYLH FusedMatMul     div
051 - | RESULT float32  4:2x1x1024x1024      AAAA Slice           /attention/Slice_4_output_0      |
052 = | RESULT float32  4:2x8x1024x1024      CYLH Add             /attention/Add_2_output_0        | RESULT float32  4:2x8x1024x1024      CYLH Add             add_2
053 = | RESULT float32  4:2x8x1024x1024      OOOO Softmax         /attention/Softmax_output_0      | RESULT float32  4:2x8x1024x1024      OOOO Softmax         val_113
054 = | RESULT float32  3:2x1024x512         OOUT MatMul          /attention/v_proj/MatMul_output_ | RESULT float32  3:2x1024x512         OOUT MatMul          matmul_2
055 = | RESULT float32  4:2x1024x8x64        OOUT Reshape         /attention/Reshape_2_output_0    | RESULT float32  4:2x1024x8x64        OOUT Reshape         view_2
056 = | RESULT float32  4:2x8x1024x64        MSZO Transpose       /attention/Transpose_2_output_0  | RESULT float32  4:2x8x1024x64        MSZO Transpose       transpose_2
057 = | RESULT float32  4:2x8x1024x64        KPAJ MatMul          /attention/MatMul_1_output_0     | RESULT float32  4:2x8x1024x64        KPAJ MatMul          matmul_5
058 = | RESULT float32  4:2x1024x8x64        AZHC Transpose       /attention/Transpose_4_output_0  | RESULT float32  4:2x1024x8x64        AZHC Transpose       transpose_5
059 = | RESULT float32  3:2x1024x512         AZHC Reshape         /attention/Reshape_3_output_0    | RESULT float32  3:2x1024x512         AZHC Reshape         view_3
060 = | RESULT float32  3:2x1024x512         LPNQ MatMul          170                              | RESULT float32  3:2x1024x512         LPNQ MatMul          matmul_6
061 = | OUTPUT float32  3:2x1024x512         LPNQ                 170                              | OUTPUT float32  3:2x1024x512         LPNQ                 matmul_6

See plot_llama_diff_export for a better view.

Total running time of the script: (0 minutes 5.316 seconds)

Gallery generated by Sphinx-Gallery