Note
Go to the end to download the full example code.
301: Compare LLaMA exporters¶
The script compares the two exporters implemented in PyTorch on a part of a LLaMA model. The models are compared after all optimizations have been applied, including the onnxruntime optimizations.
TorchScript-based ONNX Exporter, let’s call it script
TorchDynamo-based ONNX Exporter, let’s call it dynamo
To run the script:
python _doc/examples/plot_llama_diff_export --help
Some helpers¶
from experimental_experiment.args import get_parsed_args
# Parse the options of this example. Each option keyword receives a
# (default value, help text) pair; ``expose`` presumably lists the options
# made available on the command line (TODO confirm in get_parsed_args).
script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    exporter=("dynamo", "one value among dynamo, custom"),
    ortopt=(1, "run onnxruntime optimization"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,opset",
)
import contextlib
import os
import io
import warnings
import logging
# onnxruntime is required by the rest of the example: when it cannot be
# imported, print a message and exit with status 0 instead of failing.
try:
    with warnings.catch_warnings():
        # Silence warnings emitted while importing onnxruntime.
        warnings.simplefilter("ignore")
        import onnxruntime
    # First half of the CUDA check: onnxruntime must expose the CUDA provider.
    # torch's own CUDA availability is combined with it later in the script.
    has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)
import numpy as np
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.convert.convert_helper import (
optimize_model_proto_oxs,
ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
get_llama_model,
get_llama_attention,
get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import reorder_functions_in_proto
# CUDA is used only when onnxruntime ships the CUDA provider AND torch can
# see a CUDA device; torch is not queried when onnxruntime already lacks it.
if has_cuda:
    has_cuda = torch.cuda.is_available()

# Drop every log record of severity ERROR or lower to keep the output readable.
logging.disable(logging.ERROR)

# Short device name handed to ort_optimize and used to pick session providers.
if has_cuda:
    provider = "cuda"
else:
    provider = "cpu"
The exporting functions¶
# Echo the selected options so they appear in the rendered example output.
print(f"part={script_args.part}")
print(f"exporter={script_args.exporter}")
# ``ortopt`` is compared against both 1 and "1": the declared default is the
# int 1, while a value given on the command line presumably arrives as a
# string (TODO confirm in get_parsed_args).
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
# Opset version passed to the script exporter (opset_version) and to the
# custom exporter (target_opset).
opset = int(script_args.opset)
print(f"opset={opset}")
def opt_filename(filename: str) -> str:
    """Return the file name used for the onnxruntime-optimized model.

    ``.opt`` is inserted right before the extension, e.g.
    ``"llama.attention.script.onnx"`` becomes
    ``"llama.attention.script.opt.onnx"``. A name without an extension
    simply gets ``.opt`` appended.
    """
    stem, suffix = os.path.splitext(filename)
    return stem + ".opt" + suffix
def export_script(filename, model, *args):
    """Export *model* with the TorchScript-based exporter into *filename*.

    Anything printed and any warning emitted during the export are discarded.
    When ``ortopt`` is enabled, the saved model is reloaded and optimized with
    onnxruntime; the result is written to ``opt_filename(filename)``.

    :param filename: destination path of the exported ONNX model
    :param model: torch module to export
    :param args: example inputs forwarded to the model during the export
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # ``dynamo=False`` must be explicit: recent torch releases made the
            # dynamo exporter the default of torch.onnx.export, which would
            # silently turn this TorchScript baseline into a second dynamo
            # export and make the comparison meaningless.
            torch.onnx.export(
                model,
                args,
                filename,
                input_names=["input"],
                opset_version=opset,
                dynamo=False,
            )
    if ortopt:
        onx = onnx.load(filename)
        ort_optimize(onx, opt_filename(filename), providers=provider)
def export_dynamo(filename, model, *args):
    """Export *model* with the dynamo-based exporter into *filename*.

    The exported ModelProto is passed to ``optimize_model_proto_oxs``; if that
    call raises ImportError, a message is printed and the unoptimized proto is
    saved instead. When ``ortopt`` is enabled, the saved proto is also
    optimized with onnxruntime into ``opt_filename(filename)``.

    :param filename: destination path of the exported ONNX model
    :param model: torch module to export
    :param args: example inputs forwarded to the model during the export
    """
    # Keep the exporter quiet: its prints and warnings are not part of the example.
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            onnx_program = torch.onnx.export(model, args, dynamo=True)
            exported_proto = onnx_program.model_proto

    try:
        final_proto = optimize_model_proto_oxs(exported_proto)
    except ImportError as exc:
        print("skipping optimization, missing package or failure:", exc)
        final_proto = exported_proto

    serialized = final_proto.SerializeToString()
    with open(filename, "wb") as stream:
        stream.write(serialized)

    if ortopt:
        ort_optimize(final_proto, opt_filename(filename), providers=provider)
def export_custom(filename, model, *args):
    """Export *model* with experimental_experiment's ``to_onnx`` into *filename*.

    Graph inputs are named ``input0``, ``input1``, ... in argument order.
    Unused nodes are removed and constant folding is disabled. When ``ortopt``
    is enabled, the model is also optimized with onnxruntime into
    ``opt_filename(filename)``.

    :param filename: destination path of the exported ONNX model
    :param model: torch module to export
    :param args: example inputs forwarded to the model during the export
    """
    input_names = [f"input{position}" for position, _ in enumerate(args)]
    onnx_model = to_onnx(
        model,
        tuple(args),
        input_names=input_names,
        options=OptimizationOptions(remove_unused=True, constant_folding=False),
        target_opset=opset,
    )

    with open(filename, "wb") as stream:
        stream.write(onnx_model.SerializeToString())

    if ortopt:
        ort_optimize(onnx_model, opt_filename(filename), providers=provider)
part=attention
exporter=dynamo
ortopt=True
opset=18
Model and data¶
# Model configuration. Under unit tests only the input dimensions are given
# (everything else is left to the helper's defaults); otherwise a small
# single-layer configuration with eager attention is used.
if unit_test_going():
    kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
    kwargs = dict(
        input_dims=[(2, 1024)] * 2,
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

# Build the requested part of the model together with its example inputs.
# ``inputs`` presumably holds one tuple of model inputs per entry of
# ``input_dims`` (TODO confirm in llama_helper); only ``inputs[0]`` is used.
if script_args.part == "attention":
    model, inputs = get_llama_attention(**kwargs)
elif script_args.part == "decoder":
    model, inputs = get_llama_decoder(**kwargs)
elif script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")

# Run the eager model once; its result is the reference for the exported models.
print(f"simple run with {len(inputs)} inputs")
expected = model(*inputs[0])
if isinstance(expected, tuple):
    for t in expected:
        # A tuple output may mix tensors with other values (nested tuples,
        # None, ...); only tensors have a shape and a dtype to report.
        if isinstance(t, torch.Tensor):
            print(f"eager worked {t.shape}, {t.dtype}")
        else:
            print(f"eager worked {type(t)}")
else:
    print(f"eager mode worked {expected.shape}, {expected.dtype}")
simple run with 2 inputs
eager mode worked torch.Size([2, 1024, 512]), torch.float32
Exporting¶
# Export the same model twice: the TorchScript exporter gives the baseline
# (``file1``) and the exporter chosen on the command line gives ``file2``.
exporter = script_args.exporter
file1 = f"llama.{script_args.part}.script.onnx"
file2 = f"llama.{script_args.part}.{exporter}.onnx"

print("torch script exporter")
export_script(file1, model, *inputs[0])

# Second export: look up the banner and the export function by exporter name.
_second_exporters = {
    "dynamo": ("torch dynamo exporter", export_dynamo),
    "custom": ("torch custom exporter", export_custom),
}
if exporter not in _second_exporters:
    raise AssertionError(f"Unexpected value for exporter={exporter!r}.")
_banner, _export_fn = _second_exporters[exporter]
print(_banner)
_export_fn(file2, model, *inputs[0])
torch script exporter
torch dynamo exporter
Applied 7 of general pattern rewrite rules.
Verification¶
# With ``ortopt``, verify the onnxruntime-optimized files. Their names are
# derived with ``opt_filename``, the helper the export functions use to write
# them, so both sides cannot drift apart.
if ortopt:
    print("Using models optimized by onnxruntime")
    file1 = opt_filename(file1)
    file2 = opt_filename(file2)

providers = (
    ["CPUExecutionProvider"]
    if provider == "cpu"
    else [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
)

model1 = onnx.load(file1)
model2 = onnx.load(file2)

# Feed the same numpy arrays to both models. The exporters name their graph
# inputs differently, so inputs are matched by position.
feeds1, feeds2 = {}, {}
for i, tensor in enumerate(inputs[0]):
    x = tensor.detach().numpy()
    feeds1[model1.graph.input[i].name] = x
    feeds2[model2.graph.input[i].name] = x

if ortopt:
    sess1 = onnxruntime.InferenceSession(file1, providers=providers)
    sess2 = onnxruntime.InferenceSession(file2, providers=providers)
    got1 = sess1.run(None, feeds1)
    got2 = sess2.run(None, feeds2)

    # The eager model may return a tuple; compare against its first output,
    # which is what both ONNX models produce as their first output.
    expected_np = (
        expected[0] if isinstance(expected, tuple) else expected
    ).detach().numpy()
    diff1 = np.abs(expected_np - got1[0]).max()
    diff2 = np.abs(expected_np - got2[0]).max()
    print(f"Error with the eager model and onnxruntime: {diff1}, {diff2}")
Using models optimized by onnxruntime
Error with the eager model and onnxruntime: 2.9178336262702942e-05, 2.9178336262702942e-05
Verification with the reference evaluator¶
# Evaluate both models with the python reference evaluator.
# reorder_functions_in_proto presumably rewrites each file in place so the
# evaluator can load its local functions (TODO confirm in dump_helper).
reorder_functions_in_proto(file1)
reorder_functions_in_proto(file2)
sess1 = ExtendedReferenceEvaluator(file1)
try:
    sess2 = ExtendedReferenceEvaluator(file2)
except NotImplementedError as e:
    # The second model may use something the reference evaluator does not
    # implement: report it and skip that model.
    print(e)
    sess2 = None

got1 = sess1.run(None, feeds1)
# No results for the second model if it could not be loaded. Reusing the
# first model's results would report a perfect match for an unevaluated model.
got2 = None if sess2 is None else sess2.run(None, feeds2)

# The eager model may return a tuple; its first output is the reference.
expected_np = (expected[0] if isinstance(expected, tuple) else expected).detach().numpy()
diff1 = np.abs(expected_np - got1[0]).max()
diff2 = None if got2 is None else np.abs(expected_np - got2[0]).max()
print(f"Error with the eager model and the reference evaluator: {diff1}, {diff2}")
Error with the eager model and the reference evaluator: 4.0978193283081055e-08, 4.0978193283081055e-08
Comparison and execution¶
def clean_name(name):
    """Shorten a result name for the comparison table.

    Removes every occurrence of the two inlined-function markers below,
    first the LlamaAttention one, then the Linear one, and returns the
    shortened name.
    """
    markers = (
        "_inlfunc_transformers_models_llama_modeling_llama_LlamaAttention",
        "_inlfunc_torch_nn_modules_linear_Linear",
    )
    for marker in markers:
        name = name.replace(marker, "")
    return name
# Node-by-node comparison of both models. It only runs when the reference
# evaluator managed to load the second model.
if sess2 is not None:
    try:
        np_inputs = [tensor.detach().numpy() for tensor in inputs[0]]
        res1, res2, align, dc = compare_onnx_execution(
            model1, model2, inputs=np_inputs, verbose=1, raise_exc=False
        )
        # Shorten the second model's result names before rendering the table.
        for item in res2:
            item.name = clean_name(item.name)
        text = dc.to_str(res1, res2, align, column_size=90)
        print(text)
    except AssertionError as e:
        # One known failure mode (a list received where a numpy array is
        # expected) is only reported; any other assertion is re-raised.
        known_message = (
            "Unexpected type <class 'list'> for value, it must be a numpy array."
        )
        if known_message in str(e):
            print(e)
        else:
            raise
[compare_onnx_execution] execute with 3 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 60 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 60 results (first model)
[compare_onnx_execution] got 56 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 61 pairs
[compare_onnx_execution] done
001 = | INITIA float32 2:512x512 ZFXE onnx::MatMul_171 | INITIA float32 2:512x512 ZFXE t
002 = | INITIA float32 2:512x512 XETY onnx::MatMul_172 | INITIA float32 2:512x512 XETY t_1
003 = | INITIA float32 2:512x512 LEXW onnx::MatMul_173 | INITIA float32 2:512x512 LEXW t_2
004 - | INITIA float32 2:512x512 BTJW onnx::MatMul_219 |
005 = | INITIA int64 1:2 GGAA splits | INITIA int64 1:2 GGAA splits_token_14
006 - | INITIA int64 1:1 BAAA /attention/Constant_25_output_0 |
007 = | INITIA int64 1:4 CKIM /attention/Constant_2_output_0 | INITIA int64 1:4 CKIM val_2
008 + | | INITIA float32 2:512x512 BTJW t_3
009 ~ | INITIA int64 1:1 AAAA /attention/Constant_6_output_0 | INITIA int64 BAAA node_aten_unsqueeze_46_dim_0
010 = | INITIA float32 3:1x32x1 DAAA /attention/rotary_emb/Expand_out | INITIA float32 3:1x32x1 DAAA _to_copy_2
011 - | INITIA int64 1:1 KAAA /attention/Constant_24_output_0 |
012 = | INITIA int64 1:2 GGAA splits_token_14 | INITIA int64 1:2 GGAA splits
013 - | INITIA int64 1:1 DAAA const_transpose_optimizer_token_ |
014 = | INITIA int64 1:3 CKZA /attention/Constant_26_output_0 | INITIA int64 1:3 CKZA val_115
015 = | INPUT float32 3:2x1024x512 HVJE input | INPUT float32 3:2x1024x512 HVJE hidden_states
016 = | INPUT float32 4:2x1x1024x1024 AAAA onnx::Slice_1 | INPUT float32 4:2x1x1024x1024 AAAA attention_mask
017 = | INPUT int64 2:1x1024 KAQG onnx::Unsqueeze_2 | INPUT int64 2:1x1024 KAQG position_ids
018 = | RESULT int64 3:1x1x1024 KAQG Unsqueeze /attention/rotary_emb/Unsqueeze_ | RESULT int64 3:1x1x1024 KAQG Unsqueeze unsqueeze_2
019 = | RESULT float32 3:1x1x1024 KAQG Cast /attention/rotary_emb/Cast_outpu | RESULT float32 3:1x1x1024 KAQG Cast _to_copy_1
020 = | RESULT float32 3:1x32x1024 EFXM MatMul /attention/rotary_emb/MatMul_out | RESULT float32 3:1x32x1024 EFXM MatMul matmul_3
021 = | RESULT float32 3:1x64x1024 JKJK Concat /attention/rotary_emb/Concat | RESULT float32 3:1x64x1024 JKJK Concat node_Concat_64
022 = | RESULT float32 3:1x64x1024 RMRM Sin /attention/rotary_emb/Sin | RESULT float32 3:1x64x1024 RMRM Sin node_Sin_66
023 = | RESULT float32 4:1x1x64x1024 RMRM Unsqueeze /attention/Unsqueeze_1 | RESULT float32 4:1x1x64x1024 RMRM Unsqueeze node_aten_unsqueeze_73_n2
024 = | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_7_out0 | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_7_out0
025 = | RESULT float32 3:2x1024x512 CVTW MatMul /attention/k_proj/MatMul_output_ | RESULT float32 3:2x1024x512 CVTW MatMul matmul_1
026 = | RESULT float32 4:2x1024x8x64 CVTW Reshape /attention/Reshape_1_output_0 | RESULT float32 4:2x1024x8x64 CVTW Reshape view_1
027 = | RESULT float32 4:2x1024x8x32 YMXO Split /attention/Slice_2 | RESULT float32 4:2x1024x8x32 YMXO Split node_Slice_114
028 = | RESULT float32 4:2x1024x8x32 EIXJ Split /attention/Slice_3 | RESULT float32 4:2x1024x8x32 EIXJ Split node_Slice_125
029 = | RESULT float32 4:2x1024x8x32 WSDR Neg /attention/Neg_1 | RESULT float32 4:2x1024x8x32 WSDR Neg node_aten_neg_126_n0
030 = | RESULT float32 4:2x1024x8x64 TDAE Concat /attention/Concat_1 | RESULT float32 4:2x1024x8x64 TDAE Concat node_Concat_127
031 = | RESULT float32 4:2x1024x8x64 KRKR Mul /attention/Mul_3 | RESULT float32 4:2x1024x8x64 KRKR Mul node_Mul_128
032 = | RESULT float32 3:1x64x1024 NHNH Cos /attention/rotary_emb/Cos | RESULT float32 3:1x64x1024 NHNH Cos node_Cos_65
033 = | RESULT float32 4:1x1x64x1024 NHNH Unsqueeze /attention/Unsqueeze | RESULT float32 4:1x1x64x1024 NHNH Unsqueeze node_aten_unsqueeze_72_n2
034 = | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_11_out0 | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_11_out0
035 = | RESULT float32 4:2x1024x8x64 EPKU Mul /attention/Mul_2 | RESULT float32 4:2x1024x8x64 EPKU Mul node_Mul_103
036 = | RESULT float32 4:2x1024x8x64 OFUL Add /attention/Add_1 | RESULT float32 4:2x1024x8x64 OFUL Add node_Add_129
037 = | RESULT float32 4:2x8x64x1024 LILV Transpose /attention/Transpose_3_output_0 | RESULT float32 4:2x8x64x1024 LILV Transpose transpose_4
038 = | RESULT float32 4:1x1x1024x64 GSEC Transpose /attention/Unsqueeze_1_output_0 | RESULT float32 4:1x1x1024x64 GSEC Transpose unsqueeze_4
039 = | RESULT float32 3:2x1024x512 QPPG MatMul /attention/q_proj/MatMul_output_ | RESULT float32 3:2x1024x512 QPPG MatMul matmul
040 = | RESULT float32 4:2x1024x8x64 QPPG Reshape /attention/Reshape_output_0 | RESULT float32 4:2x1024x8x64 QPPG Reshape view
041 = | RESULT float32 4:2x8x1024x64 MTAU Transpose /attention/Transpose_output_0 | RESULT float32 4:2x8x1024x64 MTAU Transpose transpose
042 = | RESULT float32 4:2x8x1024x32 YZNT Split /attention/Slice_output_0 | RESULT float32 4:2x8x1024x32 YZNT Split slice_4
043 = | RESULT float32 4:2x8x1024x32 PUOB Split /attention/Slice_1_output_0 | RESULT float32 4:2x8x1024x32 PUOB Split slice_5
044 = | RESULT float32 4:2x8x1024x32 LGMZ Neg /attention/Neg_output_0 | RESULT float32 4:2x8x1024x32 LGMZ Neg neg
045 = | RESULT float32 4:2x8x1024x64 KGYT Concat /attention/Concat_output_0 | RESULT float32 4:2x8x1024x64 KGYT Concat cat_1
046 = | RESULT float32 4:2x8x1024x64 LNZK Mul /attention/Mul_1_output_0 | RESULT float32 4:2x8x1024x64 LNZK Mul mul_3
047 = | RESULT float32 4:1x1x1024x64 CJYF Transpose /attention/Unsqueeze_output_0 | RESULT float32 4:1x1x1024x64 CJYF Transpose unsqueeze_3
048 = | RESULT float32 4:2x8x1024x64 QEUP Mul /attention/Mul_output_0 | RESULT float32 4:2x8x1024x64 QEUP Mul mul_2
049 = | RESULT float32 4:2x8x1024x64 CQTA Add /attention/Add_output_0 | RESULT float32 4:2x8x1024x64 CQTA Add add
050 = | RESULT float32 4:2x8x1024x1024 CYLH FusedMatMul /attention/Div_output_0 | RESULT float32 4:2x8x1024x1024 CYLH FusedMatMul div
051 - | RESULT float32 4:2x1x1024x1024 AAAA Slice /attention/Slice_4_output_0 |
052 = | RESULT float32 4:2x8x1024x1024 CYLH Add /attention/Add_2_output_0 | RESULT float32 4:2x8x1024x1024 CYLH Add add_2
053 = | RESULT float32 4:2x8x1024x1024 OOOO Softmax /attention/Softmax_output_0 | RESULT float32 4:2x8x1024x1024 OOOO Softmax val_113
054 = | RESULT float32 3:2x1024x512 OOUT MatMul /attention/v_proj/MatMul_output_ | RESULT float32 3:2x1024x512 OOUT MatMul matmul_2
055 = | RESULT float32 4:2x1024x8x64 OOUT Reshape /attention/Reshape_2_output_0 | RESULT float32 4:2x1024x8x64 OOUT Reshape view_2
056 = | RESULT float32 4:2x8x1024x64 MSZO Transpose /attention/Transpose_2_output_0 | RESULT float32 4:2x8x1024x64 MSZO Transpose transpose_2
057 = | RESULT float32 4:2x8x1024x64 KPAJ MatMul /attention/MatMul_1_output_0 | RESULT float32 4:2x8x1024x64 KPAJ MatMul matmul_5
058 = | RESULT float32 4:2x1024x8x64 AZHC Transpose /attention/Transpose_4_output_0 | RESULT float32 4:2x1024x8x64 AZHC Transpose transpose_5
059 = | RESULT float32 3:2x1024x512 AZHC Reshape /attention/Reshape_3_output_0 | RESULT float32 3:2x1024x512 AZHC Reshape view_3
060 = | RESULT float32 3:2x1024x512 LPNQ MatMul 170 | RESULT float32 3:2x1024x512 LPNQ MatMul matmul_6
061 = | OUTPUT float32 3:2x1024x512 LPNQ 170 | OUTPUT float32 3:2x1024x512 LPNQ matmul_6
See plot_llama_diff_export for a better view.
Total running time of the script: (0 minutes 5.316 seconds)