301: Compares LLAMA exporters

This script compares the two ONNX exporters implemented in pytorch on a part of a llama model. The models are compared after all optimizations, including the ones performed by onnxruntime.
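
The comparison itself boils down to running the two exported models on the same inputs and measuring how far their outputs drift apart. A minimal sketch of that idea (the helper name and file names below are hypothetical; the real ones are built later in the script):

import numpy as np
import onnxruntime


def max_abs_diff(file_a, file_b, numpy_inputs):
    # run both ONNX files on the same inputs with onnxruntime
    sess_a = onnxruntime.InferenceSession(file_a, providers=["CPUExecutionProvider"])
    sess_b = onnxruntime.InferenceSession(file_b, providers=["CPUExecutionProvider"])
    feeds_a = {i.name: x for i, x in zip(sess_a.get_inputs(), numpy_inputs)}
    feeds_b = {i.name: x for i, x in zip(sess_b.get_inputs(), numpy_inputs)}
    # compare the first output of each model
    return float(np.abs(sess_a.run(None, feeds_a)[0] - sess_b.run(None, feeds_b)[0]).max())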

To run the script:

python _doc/examples/plot_llama_diff_export.py --help
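
For example, to compare the custom exporter on the decoder part and skip the onnxruntime optimization step (the argument names follow the parameters exposed by get_parsed_args below; the exact flag syntax may differ):

python _doc/examples/plot_llama_diff_export.py --part decoder --exporter custom --ortopt 0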

Some helpers

from experimental_experiment.args import get_parsed_args

script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    exporter=("dynamo", "one value among dynamo, custom"),
    ortopt=(1, "run onnxruntime optimization"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,opset",
)

import contextlib
import os
import io
import warnings
import logging

try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import numpy as np
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.convert.convert_helper import (
    optimize_model_proto,
    ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
    get_llama_model,
    get_llama_attention,
    get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import reorder_functions_in_proto

has_cuda = has_cuda and torch.cuda.is_available()
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"

The exporting functions

print(f"part={script_args.part}")
print(f"exporter={script_args.exporter}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
opset = int(script_args.opset)
print(f"opset={opset}")


def opt_filename(filename: str) -> str:
    name, ext = os.path.splitext(filename)
    return f"{name}.opt{ext}"


def export_script(filename, model, *args):
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            torch.onnx.export(
                model, args, filename, input_names=["input"], opset_version=opset
            )
    if ortopt:
        onx = onnx.load(filename)
        ort_optimize(onx, opt_filename(filename), providers=provider)


def export_dynamo(filename, model, *args):
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            export_output = torch.onnx.dynamo_export(model, *args)
            model = export_output.model_proto
    try:
        new_model = optimize_model_proto(model)
    except ImportError as e:
        print("skipping optimization, missing package or failure:", e)
        new_model = model
    with open(filename, "wb") as f:
        f.write(new_model.SerializeToString())
    if ortopt:
        ort_optimize(new_model, opt_filename(filename), providers=provider)


def export_custom(filename, model, *args):
    new_model = to_onnx(
        model,
        tuple(args),
        input_names=[f"input{i}" for i in range(len(args))],
        options=OptimizationOptions(
            remove_unused=True,
            constant_folding=False,
        ),
        target_opset=opset,
    )
    with open(filename, "wb") as f:
        f.write(new_model.SerializeToString())
    if ortopt:
        ort_optimize(new_model, opt_filename(filename), providers=provider)
part=attention
exporter=dynamo
ortopt=True
opset=18
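
As a side note, the optimized copies written by ort_optimize follow the naming convention of opt_filename, which inserts ".opt" before the extension:

print(opt_filename("llama.attention.script.onnx"))
# -> llama.attention.script.opt.onnx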

Model and data

if unit_test_going():
    kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
    kwargs = dict(
        input_dims=[(2, 1024)] * 2,
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

if script_args.part == "attention":
    model, inputs = get_llama_attention(**kwargs)
elif script_args.part == "decoder":
    model, inputs = get_llama_decoder(**kwargs)
elif script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")

print(f"simple run with {len(inputs)} inputs")
expected = model(*inputs[0])
if isinstance(expected, tuple):
    for t in expected:
        if not isinstance(t, tuple):
            print(f"eager worked {t.shape}, {t.dtype}")
        else:
            print(f"eager worked {type(t)}")
else:
    print(f"eager mode worked {expected.shape}, {expected.dtype}")
simple run with 2 inputs
eager mode worked torch.Size([2, 1024, 512]), torch.float32

Exporting

exporter = script_args.exporter
file1 = f"llama.{script_args.part}.script.onnx"
file2 = f"llama.{script_args.part}.{exporter}.onnx"

print("torch script exporter")
export_script(file1, model, *inputs[0])

if exporter == "dynamo":
    print("torch dynamo exporter")
    export_dynamo(file2, model, *inputs[0])
elif exporter == "custom":
    print("torch custom exporter")
    export_custom(file2, model, *inputs[0])
else:
    raise AssertionError(f"Unexpected value for exporter={exporter!r}.")
torch script exporter
torch dynamo exporter
Applied 6 pattern rewrite rules.
Applied 0 pattern rewrite rules.

Verification

if ortopt:
    print("Using models optimized by onnxruntime")
    file1 = f"llama.{script_args.part}.script.opt.onnx"
    file2 = f"llama.{script_args.part}.{exporter}.opt.onnx"


providers = (
    ["CPUExecutionProvider"]
    if provider == "cpu"
    else [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
)

model1 = onnx.load(file1)
model2 = onnx.load(file2)

feeds1, feeds2 = {}, {}
for i in range(len(inputs[0])):
    x = inputs[0][i].detach().numpy()
    feeds1[model1.graph.input[i].name] = x
    feeds2[model2.graph.input[i].name] = x

if ortopt:
    sess1 = onnxruntime.InferenceSession(file1, providers=providers)
    sess2 = onnxruntime.InferenceSession(file2, providers=providers)

    got1 = sess1.run(None, feeds1)
    got2 = sess2.run(None, feeds2)

    diff1 = np.abs(expected.detach().numpy() - got1[0]).max()
    diff2 = np.abs(expected.detach().numpy() - got2[0]).max()

    print(f"Error with the eager model and onnxruntime: {diff1}, {diff2}")
Using models optimized by onnxruntime
Error with the eager model and onnxruntime: 6.705522537231445e-08, 6.705522537231445e-08

Verification with the reference evaluator

reorder_functions_in_proto(file1)
reorder_functions_in_proto(file2)

sess1 = ExtendedReferenceEvaluator(file1)
try:
    sess2 = ExtendedReferenceEvaluator(file2)
except NotImplementedError as e:
    print(e)
    sess2 = None

got1 = sess1.run(None, feeds1)
got2 = got1 if sess2 is None else sess2.run(None, feeds2)

if isinstance(expected, tuple):
    diff1 = np.abs(expected[0].detach().numpy() - got1[0]).max()
    diff2 = np.abs(expected[0].detach().numpy() - got2[0]).max()
else:
    diff1 = np.abs(expected.detach().numpy() - got1[0]).max()
    diff2 = np.abs(expected.detach().numpy() - got2[0]).max()

print(f"Error with the eager model and the reference evaluator: {diff1}, {diff2}")
Error with the eager model and the reference evaluator: 4.0978193283081055e-08, 4.0978193283081055e-08

Comparison and execution

def clean_name(name):
    return name.replace(
        "_inlfunc_transformers_models_llama_modeling_llama_LlamaAttention", ""
    ).replace("_inlfunc_torch_nn_modules_linear_Linear", "")


if sess2 is not None:
    try:
        np_inputs = [i.detach().numpy() for i in inputs[0]]
        res1, res2, align, dc = compare_onnx_execution(
            model1, model2, inputs=np_inputs, verbose=1, raise_exc=False
        )
        for r in res2:
            r.name = clean_name(r.name)
        text = dc.to_str(res1, res2, align, column_size=90)
        print(text)
    except AssertionError as e:
        if (
            "Unexpected type <class 'list'> for value, it must be a numpy array."
            not in str(e)
        ):
            raise
        print(e)
[compare_onnx_execution] execute with 3 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 51 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 51 results (first model)
[compare_onnx_execution] got 60 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 62 pairs
[compare_onnx_execution] done
001 ~ | INITIA float32  2:512x512            GXWA                 onnx::MatMul_131                 | INITIA int64    1:2                  BKAA                 ortshared_7_1_2_0_token_109
002 + |                                                                                            | INITIA float32  2:1024x64            GSEC                 _val_32__1
003 + |                                                                                            | INITIA int64    1:3                  CKSA                 ortshared_7_1_3_0_token_108
004 + |                                                                                            | INITIA int64                         BAAA                 ortshared_7_0_1_0_token_107
005 + |                                                                                            | INITIA float32                       BAAA                 ortshared_1_0_1_1_token_116
006 ~ | INITIA float32  2:512x512            HGDF                 onnx::MatMul_132                 | INITIA float32  2:512x512            LFJJ                 torch_nn_modules_linear_Linear_a
007 ~ | INITIA float32  2:512x512            AXYW                 onnx::MatMul_133                 | INITIA float32  2:512x512            GXWA                 torch_nn_modules_linear_Linear_a
008 ~ | INITIA float32  2:512x512            LFJJ                 onnx::MatMul_169                 | INITIA float32  2:512x512            HGDF                 torch_nn_modules_linear_Linear_a
009 + |                                                                                            | INITIA float32  2:512x512            AXYW                 torch_nn_modules_linear_Linear_a
010 ~ | INITIA int64    1:4                  CKIM                 ortshared_7_1_4_0_token_76       | INITIA int64    1:2                  GGAA                 splits_token_118
011 ~ | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64                         ZAAA                 ortshared_7_0_1_1_token_114
012 ~ | INITIA int64    1:3                  CKSA                 ortshared_7_1_3_0_token_80       | INITIA int64    1:4                  CKIM                 ortshared_7_1_4_0_token_113
013 = | INITIA float32  2:1024x64            CJYF                 /attention/rotary_emb/Constant_o | INITIA float32  2:1024x64            CJYF                 _val_22__1
014 - | INITIA float32  2:1024x64            GSEC                 /attention/rotary_emb/Constant_1 |
015 = | INITIA int64    1:2                  GGAA                 splits_token_81                  | INITIA int64    1:2                  GGAA                 splits
016 - | INITIA int64    1:1                  BAAA                 ortshared_7_1_1_3_token_78       |
017 = | INPUT  float32  3:2x1024x512         ULQF                 input                            | INPUT  float32  3:2x1024x512         ULQF                 l_hidden_states_
018 = | INPUT  float32  4:2x1x1024x1024      AAAA                 onnx::Add_1                      | INPUT  float32  4:2x1x1024x1024      AAAA                 l_attention_mask_
019 = | INPUT  int64    2:1x1024             KAQG                 position_ids                     | INPUT  int64    2:1x1024             KAQG                 l_position_ids_
020 + |                                                                                            | RESULT int64    2:1x1024             KAQG Expand          _val_35__1
021 + |                                                                                            | RESULT int64    3:1x1024x1           KAQG Unsqueeze       _val_37__1
022 + |                                                                                            | RESULT int64    3:1x1024x1           KAQG Concat          _val_38__1
023 ~ | RESULT float32  3:1x1024x64          GSEC Gather          /attention/Gather_1_output_0     | RESULT float32  3:1x1024x64          GSEC GatherND        _val_39__1
024 = | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       /attention/Unsqueeze_1_output_0  | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       aten_unsqueeze_65_n2__1
025 = | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_4_out0           | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_5_out0
026 = | RESULT float32  3:2x1024x512         KRRM MatMul          /attention/k_proj/MatMul_output_ | RESULT float32  3:2x1024x512         KRRM MatMul          attention_k_proj_1__1
027 = | RESULT float32  4:2x1024x8x64        KRRM Reshape         /attention/Reshape_1_output_0    | RESULT float32  4:2x1024x8x64        KRRM Reshape         view_7__1
028 = | RESULT float32  4:2x1024x8x32        YVML Split           /attention/Slice_2               | RESULT float32  4:2x1024x8x32        YVML Split           Slice_123__1
029 = | RESULT float32  4:2x1024x8x32        MWFB Split           /attention/Slice_3               | RESULT float32  4:2x1024x8x32        MWFB Split           Slice_140__1
030 = | RESULT float32  4:2x1024x8x32        OEVZ Neg             /attention/Neg_1                 | RESULT float32  4:2x1024x8x32        OEVZ Neg             aten_neg_141_n0__1
031 = | RESULT float32  4:2x1024x8x64        NZHK Concat          /attention/Concat_1              | RESULT float32  4:2x1024x8x64        NZHK Concat          aten_cat_143_n0__1
032 = | RESULT float32  4:2x1024x8x64        VUBG Mul             /attention/Mul_3                 | RESULT float32  4:2x1024x8x64        VUBG Mul             aten_mul_144_n0__1
033 ~ | RESULT float32  3:1x1024x64          CJYF Gather          /attention/Gather_output_0       | RESULT float32  3:1x1024x64          CJYF GatherND        _val_29__1
034 = | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       /attention/Unsqueeze_output_0    | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       aten_unsqueeze_55_n2__1
035 = | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_6_out0           | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_8_out0
036 = | RESULT float32  4:2x1024x8x64        GRNX Mul             /attention/Mul_2                 | RESULT float32  4:2x1024x8x64        GRNX Mul             aten_mul_106_n0__1
037 = | RESULT float32  4:2x1024x8x64        BLPD Add             /attention/Add_1                 | RESULT float32  4:2x1024x8x64        BLPD Add             n3__3
038 = | RESULT float32  4:2x8x64x1024        EJHL Transpose       /attention/Transpose_3_output_0  | RESULT float32  4:2x8x64x1024        EJHL Transpose       transpose_3__1
039 + |                                                                                            | RESULT float32  4:1x1x1024x64        GSEC Transpose       unsqueeze_1__1
040 = | RESULT float32  3:2x1024x512         OSYT MatMul          /attention/q_proj/MatMul_output_ | RESULT float32  3:2x1024x512         OSYT MatMul          attention_q_proj_1__1
041 = | RESULT float32  4:2x1024x8x64        OSYT Reshape         /attention/Reshape_output_0      | RESULT float32  4:2x1024x8x64        OSYT Reshape         view_6__1
042 = | RESULT float32  4:2x8x1024x64        HAKH Transpose       /attention/Transpose_output_0    | RESULT float32  4:2x8x1024x64        HAKH Transpose       transpose__1
043 = | RESULT float32  4:2x8x1024x32        EVBF Split           /attention/Slice_output_0        | RESULT float32  4:2x8x1024x32        EVBF Split           slice_3__1
044 = | RESULT float32  4:2x8x1024x32        CEID Split           /attention/Slice_1_output_0      | RESULT float32  4:2x8x1024x32        CEID Split           slice_4__1
045 = | RESULT float32  4:2x8x1024x32        YWSX Neg             /attention/Neg_output_0          | RESULT float32  4:2x8x1024x32        YWSX Neg             neg__1
046 = | RESULT float32  4:2x8x1024x64        DSTB Concat          /attention/Concat_output_0       | RESULT float32  4:2x8x1024x64        DSTB Concat          cat__1
047 = | RESULT float32  4:2x8x1024x64        NHCJ Mul             /attention/Mul_1_output_0        | RESULT float32  4:2x8x1024x64        NHCJ Mul             mul_1__1
048 + |                                                                                            | RESULT float32  4:1x1x1024x64        CJYF Transpose       unsqueeze__1
049 = | RESULT float32  4:2x8x1024x64        IUYZ Mul             /attention/Mul_output_0          | RESULT float32  4:2x8x1024x64        IUYZ Mul             mul__1
050 = | RESULT float32  4:2x8x1024x64        VCBI Add             /attention/Add_output_0          | RESULT float32  4:2x8x1024x64        VCBI Add             add__1
051 = | RESULT float32  4:2x8x1024x1024      AWFA FusedMatMul     /attention/Div_output_0          | RESULT float32  4:2x8x1024x1024      AWFA FusedMatMul     div__1
052 + |                                                                                            | RESULT float32  4:2x1x1024x1024      AAAA Mul             other_1__4
053 = | RESULT float32  4:2x8x1024x1024      AWFA Add             /attention/Add_2_output_0        | RESULT float32  4:2x8x1024x1024      AWFA Add             add_2__1
054 = | RESULT float32  4:2x8x1024x1024      NNON Softmax         /attention/Softmax_output_0      | RESULT float32  4:2x8x1024x1024      NNON Softmax         _softmax__1
055 = | RESULT float32  3:2x1024x512         HMLX MatMul          /attention/v_proj/MatMul_output_ | RESULT float32  3:2x1024x512         HMLX MatMul          attention_v_proj_1__1
056 = | RESULT float32  4:2x1024x8x64        HMLX Reshape         /attention/Reshape_2_output_0    | RESULT float32  4:2x1024x8x64        HMLX Reshape         view_8__1
057 = | RESULT float32  4:2x8x1024x64        FOKY Transpose       /attention/Transpose_2_output_0  | RESULT float32  4:2x8x1024x64        FOKY Transpose       transpose_2__1
058 = | RESULT float32  4:2x8x1024x64        PLNS MatMul          /attention/MatMul_1_output_0     | RESULT float32  4:2x8x1024x64        PLNS MatMul          view_14__1
059 = | RESULT float32  4:2x1024x8x64        BZDC Transpose       /attention/Transpose_4_output_0  | RESULT float32  4:2x1024x8x64        BZDC Transpose       transpose_4__1
060 = | RESULT float32  3:2x1024x512         BZDC Reshape         /attention/Reshape_3_output_0    | RESULT float32  3:2x1024x512         BZDC Reshape         view_15__1
061 = | RESULT float32  3:2x1024x512         OPNS MatMul          130                              | RESULT float32  3:2x1024x512         OPNS MatMul          attention_1
062 = | OUTPUT float32  3:2x1024x512         OPNS                 130                              | OUTPUT float32  3:2x1024x512         OPNS                 attention_1

See plot_llama_diff_export for a better view.

Total running time of the script: (0 minutes 31.156 seconds)
