301: Compares LLAMA exporters¶
The script compares the two ONNX exporters implemented in pytorch on a part of a llama model. The exported models are compared after all optimizations have been applied by onnxruntime.
TorchScript-based ONNX Exporter, let’s call it script
TorchDynamo-based ONNX Exporter, let’s call it dynamo
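Both exporters are available through torch.onnx. A minimal sketch of the two entry points on a placeholder model (the actual calls, with all their options, appear in the exporting functions below):

import torch

class Tiny(torch.nn.Module):
    # placeholder module standing in for the llama part
    def forward(self, x):
        return x * 2

model, args = Tiny(), (torch.randn(2, 4),)

# TorchScript-based exporter ("script")
torch.onnx.export(model, args, "tiny.script.onnx", opset_version=18)

# TorchDynamo-based exporter ("dynamo")
export_output = torch.onnx.dynamo_export(model, *args)
with open("tiny.dynamo.onnx", "wb") as f:
    f.write(export_output.model_proto.SerializeToString())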
To run the script:
python _doc/examples/plot_llama_diff_export.py --help
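For example, to reproduce the run shown below (assuming get_parsed_args exposes every argument as a --name option):

python _doc/examples/plot_llama_diff_export.py --part attention --exporter dynamo --ortopt 1 --opset 18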
Some helpers¶
from experimental_experiment.args import get_parsed_args

script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    exporter=("dynamo", "one value among dynamo, custom"),
    ortopt=(1, "run onnxruntime optimization"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,opset",
)
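get_parsed_args turns each (default, help) pair into a command-line option. A rough argparse equivalent, for illustration only (the helper may differ in details):

import argparse

parser = argparse.ArgumentParser("plot_llama_diff_export", description=__doc__)
parser.add_argument("--part", default="attention",
                    help="one value among attention, decoder, model")
parser.add_argument("--exporter", default="dynamo",
                    help="one value among dynamo, custom")
parser.add_argument("--ortopt", default=1,
                    help="run onnxruntime optimization")
parser.add_argument("--opset", default=18, type=int, help="onnx opset")
script_args = parser.parse_args()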
import contextlib
import os
import io
import warnings
import logging

try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import numpy as np
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.convert.convert_helper import (
    optimize_model_proto,
    ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
    get_llama_model,
    get_llama_attention,
    get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import reorder_functions_in_proto

has_cuda = has_cuda and torch.cuda.is_available()
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"
The exporting functions¶
print(f"part={script_args.part}")
print(f"exporter={script_args.exporter}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
opset = int(script_args.opset)
print(f"opset={opset}")
def opt_filename(filename: str) -> str:
name, ext = os.path.splitext(filename)
return f"{name}.opt{ext}"
def export_script(filename, model, *args):
with contextlib.redirect_stdout(io.StringIO()):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
torch.onnx.export(
model, args, filename, input_names=["input"], opset_version=opset
)
if ortopt:
onx = onnx.load(filename)
ort_optimize(onx, opt_filename(filename), providers=provider)
def export_dynamo(filename, model, *args):
with contextlib.redirect_stdout(io.StringIO()):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
export_output = torch.onnx.dynamo_export(model, *args)
model = export_output.model_proto
try:
new_model = optimize_model_proto(model)
except ImportError as e:
print("skipping optimization, missing package or failure:", e)
new_model = model
with open(filename, "wb") as f:
f.write(new_model.SerializeToString())
if ortopt:
ort_optimize(new_model, opt_filename(filename), providers=provider)
def export_custom(filename, model, *args):
new_model = to_onnx(
model,
tuple(args),
input_names=[f"input{i}" for i in range(len(args))],
options=OptimizationOptions(
remove_unused=True,
constant_folding=False,
),
target_opset=opset,
)
with open(filename, "wb") as f:
f.write(new_model.SerializeToString())
if ortopt:
ort_optimize(new_model, opt_filename(filename), providers=provider)
part=attention
exporter=dynamo
ortopt=True
opset=18
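ort_optimize stores the model once onnxruntime has optimized it. A minimal sketch of what such an offline optimization looks like with onnxruntime's public API (the helper may be implemented differently):

import onnxruntime

so = onnxruntime.SessionOptions()
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# building the session writes the optimized graph to this file
so.optimized_model_filepath = "llama.attention.script.opt.onnx"
onnxruntime.InferenceSession(
    "llama.attention.script.onnx", so, providers=["CPUExecutionProvider"]
)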
Model and data¶
if unit_test_going():
    kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
    kwargs = dict(
        input_dims=[(2, 1024)] * 2,
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

if script_args.part == "attention":
    model, inputs = get_llama_attention(**kwargs)
elif script_args.part == "decoder":
    model, inputs = get_llama_decoder(**kwargs)
elif script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")

print(f"simple run with {len(inputs)} inputs")
expected = model(*inputs[0])
if isinstance(expected, tuple):
    for t in expected:
        if not isinstance(t, tuple):
            print(f"eager worked {t.shape}, {t.dtype}")
        else:
            print(f"eager worked {type(t)}")
else:
    print(f"eager mode worked {expected.shape}, {expected.dtype}")
simple run with 2 inputs
eager mode worked torch.Size([2, 1024, 512]), torch.float32
Exporting¶
exporter = script_args.exporter
file1 = f"llama.{script_args.part}.script.onnx"
file2 = f"llama.{script_args.part}.{exporter}.onnx"

print("torch script exporter")
export_script(file1, model, *inputs[0])

if exporter == "dynamo":
    print("torch dynamo exporter")
    export_dynamo(file2, model, *inputs[0])
elif exporter == "custom":
    print("torch custom exporter")
    export_custom(file2, model, *inputs[0])
else:
    raise AssertionError(f"Unexpected value for exporter={exporter!r}.")
torch script exporter
torch dynamo exporter
Applied 6 pattern rewrite rules.
Applied 0 pattern rewrite rules.
Verification¶
if ortopt:
    print("Using models optimized by onnxruntime")
    file1 = f"llama.{script_args.part}.script.opt.onnx"
    file2 = f"llama.{script_args.part}.{exporter}.opt.onnx"

providers = (
    ["CPUExecutionProvider"]
    if provider == "cpu"
    else [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
)

model1 = onnx.load(file1)
model2 = onnx.load(file2)

feeds1, feeds2 = {}, {}
for i in range(len(inputs[0])):
    x = inputs[0][i].detach().numpy()
    feeds1[model1.graph.input[i].name] = x
    feeds2[model2.graph.input[i].name] = x

if ortopt:
    sess1 = onnxruntime.InferenceSession(file1, providers=providers)
    sess2 = onnxruntime.InferenceSession(file2, providers=providers)
    got1 = sess1.run(None, feeds1)
    got2 = sess2.run(None, feeds2)
    diff1 = np.abs(expected.detach().numpy() - got1[0]).max()
    diff2 = np.abs(expected.detach().numpy() - got2[0]).max()
    print(f"Error with the eager model and onnxruntime: {diff1}, {diff2}")
Using models optimized by onnxruntime
Error with the eager model and onnxruntime: 6.705522537231445e-08, 6.705522537231445e-08
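A discrepancy around 1e-7 is at the level of float32 rounding and can be considered negligible:

import numpy as np

# smallest relative spacing of float32, about 7 significant decimal digits
print(np.finfo(np.float32).eps)  # 1.1920929e-07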
Verification with the reference evaluator¶
reorder_functions_in_proto(file1)
reorder_functions_in_proto(file2)

sess1 = ExtendedReferenceEvaluator(file1)
try:
    sess2 = ExtendedReferenceEvaluator(file2)
except NotImplementedError as e:
    print(e)
    sess2 = None

got1 = sess1.run(None, feeds1)
got2 = got1 if sess2 is None else sess2.run(None, feeds2)

if isinstance(expected, tuple):
    diff1 = np.abs(expected[0].detach().numpy() - got1[0]).max()
    diff2 = np.abs(expected[0].detach().numpy() - got2[0]).max()
else:
    diff1 = np.abs(expected.detach().numpy() - got1[0]).max()
    diff2 = np.abs(expected.detach().numpy() - got2[0]).max()
print(f"Error with the eager model and the reference evaluator: {diff1}, {diff2}")
Error with the eager model and the reference evaluator: 4.0978193283081055e-08, 4.0978193283081055e-08
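Both optimized models contain FusedMatMul, a com.microsoft contrib operator inserted by onnxruntime (it shows up in the comparison below). The standard onnx reference evaluator only implements the official operator set, which is what the NotImplementedError guard above protects against; ExtendedReferenceEvaluator adds kernels for such operators. A quick check, assuming the standard evaluator rejects unknown operators when it is built:

from onnx.reference import ReferenceEvaluator

try:
    # only implements operators from the official onnx domains
    ReferenceEvaluator(file1)
except NotImplementedError as e:
    print("standard evaluator:", e)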
Comparison and execution¶
def clean_name(name):
    return name.replace(
        "_inlfunc_transformers_models_llama_modeling_llama_LlamaAttention", ""
    ).replace("_inlfunc_torch_nn_modules_linear_Linear", "")


if sess2 is not None:
    try:
        np_inputs = [i.detach().numpy() for i in inputs[0]]
        res1, res2, align, dc = compare_onnx_execution(
            model1, model2, inputs=np_inputs, verbose=1, raise_exc=False
        )
        for r in res2:
            r.name = clean_name(r.name)
        text = dc.to_str(res1, res2, align, column_size=90)
        print(text)
    except AssertionError as e:
        if (
            "Unexpected type <class 'list'> for value, it must be a numpy array."
            not in str(e)
        ):
            raise
        print(e)
[compare_onnx_execution] execute with 3 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 51 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 51 results (first model)
[compare_onnx_execution] got 60 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 62 pairs
[compare_onnx_execution] done
001 ~ | INITIA float32 2:512x512 GXWA onnx::MatMul_131 | INITIA int64 1:2 BKAA ortshared_7_1_2_0_token_109
002 + | | INITIA float32 2:1024x64 GSEC _val_32__1
003 + | | INITIA int64 1:3 CKSA ortshared_7_1_3_0_token_108
004 + | | INITIA int64 BAAA ortshared_7_0_1_0_token_107
005 + | | INITIA float32 BAAA ortshared_1_0_1_1_token_116
006 ~ | INITIA float32 2:512x512 HGDF onnx::MatMul_132 | INITIA float32 2:512x512 LFJJ torch_nn_modules_linear_Linear_a
007 ~ | INITIA float32 2:512x512 AXYW onnx::MatMul_133 | INITIA float32 2:512x512 GXWA torch_nn_modules_linear_Linear_a
008 ~ | INITIA float32 2:512x512 LFJJ onnx::MatMul_169 | INITIA float32 2:512x512 HGDF torch_nn_modules_linear_Linear_a
009 + | | INITIA float32 2:512x512 AXYW torch_nn_modules_linear_Linear_a
010 ~ | INITIA int64 1:4 CKIM ortshared_7_1_4_0_token_76 | INITIA int64 1:2 GGAA splits_token_118
011 ~ | INITIA int64 1:2 GGAA splits | INITIA int64 ZAAA ortshared_7_0_1_1_token_114
012 ~ | INITIA int64 1:3 CKSA ortshared_7_1_3_0_token_80 | INITIA int64 1:4 CKIM ortshared_7_1_4_0_token_113
013 = | INITIA float32 2:1024x64 CJYF /attention/rotary_emb/Constant_o | INITIA float32 2:1024x64 CJYF _val_22__1
014 - | INITIA float32 2:1024x64 GSEC /attention/rotary_emb/Constant_1 |
015 = | INITIA int64 1:2 GGAA splits_token_81 | INITIA int64 1:2 GGAA splits
016 - | INITIA int64 1:1 BAAA ortshared_7_1_1_3_token_78 |
017 = | INPUT float32 3:2x1024x512 ULQF input | INPUT float32 3:2x1024x512 ULQF l_hidden_states_
018 = | INPUT float32 4:2x1x1024x1024 AAAA onnx::Add_1 | INPUT float32 4:2x1x1024x1024 AAAA l_attention_mask_
019 = | INPUT int64 2:1x1024 KAQG position_ids | INPUT int64 2:1x1024 KAQG l_position_ids_
020 + | | RESULT int64 2:1x1024 KAQG Expand _val_35__1
021 + | | RESULT int64 3:1x1024x1 KAQG Unsqueeze _val_37__1
022 + | | RESULT int64 3:1x1024x1 KAQG Concat _val_38__1
023 ~ | RESULT float32 3:1x1024x64 GSEC Gather /attention/Gather_1_output_0 | RESULT float32 3:1x1024x64 GSEC GatherND _val_39__1
024 = | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze /attention/Unsqueeze_1_output_0 | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze aten_unsqueeze_65_n2__1
025 = | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_4_out0 | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_5_out0
026 = | RESULT float32 3:2x1024x512 KRRM MatMul /attention/k_proj/MatMul_output_ | RESULT float32 3:2x1024x512 KRRM MatMul attention_k_proj_1__1
027 = | RESULT float32 4:2x1024x8x64 KRRM Reshape /attention/Reshape_1_output_0 | RESULT float32 4:2x1024x8x64 KRRM Reshape view_7__1
028 = | RESULT float32 4:2x1024x8x32 YVML Split /attention/Slice_2 | RESULT float32 4:2x1024x8x32 YVML Split Slice_123__1
029 = | RESULT float32 4:2x1024x8x32 MWFB Split /attention/Slice_3 | RESULT float32 4:2x1024x8x32 MWFB Split Slice_140__1
030 = | RESULT float32 4:2x1024x8x32 OEVZ Neg /attention/Neg_1 | RESULT float32 4:2x1024x8x32 OEVZ Neg aten_neg_141_n0__1
031 = | RESULT float32 4:2x1024x8x64 NZHK Concat /attention/Concat_1 | RESULT float32 4:2x1024x8x64 NZHK Concat aten_cat_143_n0__1
032 = | RESULT float32 4:2x1024x8x64 VUBG Mul /attention/Mul_3 | RESULT float32 4:2x1024x8x64 VUBG Mul aten_mul_144_n0__1
033 ~ | RESULT float32 3:1x1024x64 CJYF Gather /attention/Gather_output_0 | RESULT float32 3:1x1024x64 CJYF GatherND _val_29__1
034 = | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze /attention/Unsqueeze_output_0 | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze aten_unsqueeze_55_n2__1
035 = | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_6_out0 | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_8_out0
036 = | RESULT float32 4:2x1024x8x64 GRNX Mul /attention/Mul_2 | RESULT float32 4:2x1024x8x64 GRNX Mul aten_mul_106_n0__1
037 = | RESULT float32 4:2x1024x8x64 BLPD Add /attention/Add_1 | RESULT float32 4:2x1024x8x64 BLPD Add n3__3
038 = | RESULT float32 4:2x8x64x1024 EJHL Transpose /attention/Transpose_3_output_0 | RESULT float32 4:2x8x64x1024 EJHL Transpose transpose_3__1
039 + | | RESULT float32 4:1x1x1024x64 GSEC Transpose unsqueeze_1__1
040 = | RESULT float32 3:2x1024x512 OSYT MatMul /attention/q_proj/MatMul_output_ | RESULT float32 3:2x1024x512 OSYT MatMul attention_q_proj_1__1
041 = | RESULT float32 4:2x1024x8x64 OSYT Reshape /attention/Reshape_output_0 | RESULT float32 4:2x1024x8x64 OSYT Reshape view_6__1
042 = | RESULT float32 4:2x8x1024x64 HAKH Transpose /attention/Transpose_output_0 | RESULT float32 4:2x8x1024x64 HAKH Transpose transpose__1
043 = | RESULT float32 4:2x8x1024x32 EVBF Split /attention/Slice_output_0 | RESULT float32 4:2x8x1024x32 EVBF Split slice_3__1
044 = | RESULT float32 4:2x8x1024x32 CEID Split /attention/Slice_1_output_0 | RESULT float32 4:2x8x1024x32 CEID Split slice_4__1
045 = | RESULT float32 4:2x8x1024x32 YWSX Neg /attention/Neg_output_0 | RESULT float32 4:2x8x1024x32 YWSX Neg neg__1
046 = | RESULT float32 4:2x8x1024x64 DSTB Concat /attention/Concat_output_0 | RESULT float32 4:2x8x1024x64 DSTB Concat cat__1
047 = | RESULT float32 4:2x8x1024x64 NHCJ Mul /attention/Mul_1_output_0 | RESULT float32 4:2x8x1024x64 NHCJ Mul mul_1__1
048 + | | RESULT float32 4:1x1x1024x64 CJYF Transpose unsqueeze__1
049 = | RESULT float32 4:2x8x1024x64 IUYZ Mul /attention/Mul_output_0 | RESULT float32 4:2x8x1024x64 IUYZ Mul mul__1
050 = | RESULT float32 4:2x8x1024x64 VCBI Add /attention/Add_output_0 | RESULT float32 4:2x8x1024x64 VCBI Add add__1
051 = | RESULT float32 4:2x8x1024x1024 AWFA FusedMatMul /attention/Div_output_0 | RESULT float32 4:2x8x1024x1024 AWFA FusedMatMul div__1
052 + | | RESULT float32 4:2x1x1024x1024 AAAA Mul other_1__4
053 = | RESULT float32 4:2x8x1024x1024 AWFA Add /attention/Add_2_output_0 | RESULT float32 4:2x8x1024x1024 AWFA Add add_2__1
054 = | RESULT float32 4:2x8x1024x1024 NNON Softmax /attention/Softmax_output_0 | RESULT float32 4:2x8x1024x1024 NNON Softmax _softmax__1
055 = | RESULT float32 3:2x1024x512 HMLX MatMul /attention/v_proj/MatMul_output_ | RESULT float32 3:2x1024x512 HMLX MatMul attention_v_proj_1__1
056 = | RESULT float32 4:2x1024x8x64 HMLX Reshape /attention/Reshape_2_output_0 | RESULT float32 4:2x1024x8x64 HMLX Reshape view_8__1
057 = | RESULT float32 4:2x8x1024x64 FOKY Transpose /attention/Transpose_2_output_0 | RESULT float32 4:2x8x1024x64 FOKY Transpose transpose_2__1
058 = | RESULT float32 4:2x8x1024x64 PLNS MatMul /attention/MatMul_1_output_0 | RESULT float32 4:2x8x1024x64 PLNS MatMul view_14__1
059 = | RESULT float32 4:2x1024x8x64 BZDC Transpose /attention/Transpose_4_output_0 | RESULT float32 4:2x1024x8x64 BZDC Transpose transpose_4__1
060 = | RESULT float32 3:2x1024x512 BZDC Reshape /attention/Reshape_3_output_0 | RESULT float32 3:2x1024x512 BZDC Reshape view_15__1
061 = | RESULT float32 3:2x1024x512 OPNS MatMul 130 | RESULT float32 3:2x1024x512 OPNS MatMul attention_1
062 = | OUTPUT float32 3:2x1024x512 OPNS 130 | OUTPUT float32 3:2x1024x512 OPNS attention_1
In the rows above, = marks a pair of results aligned and identical in both models, ~ a pair aligned but different, + a result only produced by the second model, and - a result only produced by the first one. The four-letter codes (GXWA, CJYF, ...) are short fingerprints of the tensor content used to detect matching results. See plot_llama_diff_export for a better view.
Total running time of the script: (0 minutes 31.156 seconds)