301: Compares LLAMA exporters for onnxrt backend
The script compares models exported from PyTorch with the onnxrt backend and with a debugging backend. It runs a side-by-side comparison of the execution of both exported models.
To run the script:
python _doc/examples/plot_llama_diff_dort --help
The following example compares the forward step in mixed precision on cuda and produces all the intermediate onnx graphs.
You may use --mixed=1 to enable mixed precision and --backward=1 to compare the backward graphs.
Some helpers
from experimental_experiment.args import get_parsed_args
script_args = get_parsed_args(
"plot_llama_diff_export",
description=__doc__,
part=("model", "one value among model, ..."),
ortopt=(1, "run onnxruntime optimization"),
backward=(0, "does one operator for backward"),
cuda=(0, "use cuda or not"),
mixed=(0, "use mixed precision"),
opset=(18, "onnx opset"),
expose="part,exporter,ortopt,cuda,mixed,opset",
)
import copy
import os
import warnings
import logging
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import onnxruntime
has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
print("onnxruntime not available.")
import sys
sys.exit(0)
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
ort_optimize,
optimize_model_proto_oxs,
)
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.torch_models.dump_helper import (
assert_all_close,
dump_onnx,
reorder_functions_in_proto,
inputs_from_onnx_model,
build_matching_inputs,
results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
train_loop,
make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
onnx_debug_backend,
get_decomposition_table,
)
has_cuda = has_cuda and torch.cuda.device_count() > 0
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"
The exporting functions
print(f"part={script_args.part}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
backward = script_args.backward in (1, "1")
print(f"backward={backward}")
use_cuda = script_args.cuda in (1, "1")
print(f"cuda={use_cuda}")
use_mixed = script_args.mixed in (1, "1")
print(f"mixed={use_mixed}")
opset = int(script_args.opset)
print(f"opset={opset}")
part=model
ortopt=True
backward=False
cuda=False
mixed=False
opset=18
Model and data
if unit_test_going():
kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
kwargs = dict(
input_dims=[(2, 1024)] * 2,
_attn_implementation="eager",
num_hidden_layers=1,
hidden_size=512,
vocab_size=4000,
intermediate_size=2000,
max_position_embeddings=2048,
num_attention_heads=8,
)
if script_args.part == "model":
model, inputs = get_llama_model(**kwargs)
else:
raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
if use_cuda:
model = model.to("cuda")
inputs = [[i.to("cuda") for i in inp] for inp in inputs]
print(f"simple run with {len(inputs)} inputs")
if backward:
if use_mixed:
assert use_cuda, "mixed precision only works with cuda"
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected = train_loop(copy.deepcopy(model), *inputs[0])
torch.cuda.synchronize()
else:
expected = train_loop(copy.deepcopy(model), *inputs[0])
print(
f"-- eager mode worked, {len(expected)} gradients, first one is "
f"{expected[0].shape}, {expected[0].dtype}"
)
else:
if use_mixed:
assert use_cuda, "mixed precision only works with cuda"
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected = model(*inputs[0])
torch.cuda.synchronize()
else:
expected = model(*inputs[0])
print(results_to_string(expected))
simple run with 2 inputs
1 results
torch.float32 (2, 1024, 512) [sum=1.39e+04]
Exporting
if hasattr(torch._dynamo.variables.misc, "LoggingLoggerVariable"):
# A tweak to make torch.export.export work.
torch._dynamo.variables.misc.LoggingLoggerVariable.call_method = lambda *_, **__: None
folder = "dump_models"
storage = {}
if backward:
# onnxrt backend
local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)
optimized_mod = torch.compile(
copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
)
with dump_onnx("llama_onnxrt", folder=folder, clean=True):
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected_onnxrt = train_loop(optimized_mod, *inputs[0])
torch.cuda.synchronize()
else:
expected_onnxrt = train_loop(optimized_mod, *inputs[0])
assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
print(
f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
)
# debugging backend
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
*args,
dump_prefix=os.path.join(folder, "llama_debug"),
target_opset=opset,
storage=storage,
**kwargs,
),
decompositions=get_decomposition_table(),
)
onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
got = train_loop(onnx_mod, *inputs[0])
torch.cuda.synchronize()
else:
got = train_loop(onnx_mod, *inputs[0])
assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
print(
f"-- debug backend worked, {len(got)} gradients, first one is "
f"{got[0].shape}, {got[0].dtype}"
)
else:
# onnxrt backend
local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
with dump_onnx("llama_onnxrt", folder=folder, clean=True):
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected_onnxrt = optimized_mod(*inputs[0])
torch.cuda.synchronize()
else:
expected_onnxrt = optimized_mod(*inputs[0])
assert_all_close(expected, expected_onnxrt, atol=1e-2)
# debugging backend
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
*args,
dump_prefix=os.path.join(folder, "llama_debug"),
target_opset=17,
storage=storage,
**kwargs,
)
)
onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
got = onnx_mod(*inputs[0])
else:
try:
got = onnx_mod(*inputs[0])
except Exception as e:
print(f"ERROR: {e}")
got = None
if got is not None:
assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:109: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py:505: FutureWarning: 'onnxscript.values.TracedOnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
self.param_schema = self.onnxfunction.param_schemas()
For the forward pass, there are two files: one onnx model and the graph module printed in a text file. For the backward pass, there are two onnx models. This count is then multiplied by the number of backends.
models = os.listdir(folder)
print(f"exported models: {models}")
exported models: ['llama_debug_0.onnx', 'llama_onnxrt_0.txt', 'llama_debug_0.txt', 'llama_onnxrt_0.onnx']
Inputs used by the debug backend
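The listing below gives the name, element type, and shape of every input consumed by the debug model. Assuming a feeds dictionary mapping input names to numpy arrays (the same feeds later passed to build_matching_inputs and to the reference evaluator), a minimal sketch producing this kind of listing:

# Sketch under the assumption that `feeds` maps input names to numpy arrays.
for name, value in feeds.items():
    print(f"-- {name} {value.dtype} {value.shape}")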
-- input0 int64 (2, 1024)
-- input1 float32 (4000, 512)
-- input2 float32 (2, 1024)
-- input3 float32 (32,)
-- input4 float32 (512,)
-- input5 float32 (512, 512)
-- input6 float32 (512, 512)
-- input7 float32 (512, 512)
-- input8 float32 (512, 512)
-- input9 float32 (512,)
-- input10 float32 (2000, 512)
-- input11 float32 (2000, 512)
-- input12 float32 (512, 2000)
-- input13 float32 (512,)
Let’s print the first lines of the graph module.
if "instance" in storage:
graph_module = storage["instance"][0]["graph_module"]
print("\n".join(str(graph_module.graph).split("\n")[:10]))
graph():
%primals_1 : [num_users=2] = placeholder[target=primals_1]
%primals_2 : [num_users=1] = placeholder[target=primals_2]
%primals_3 : [num_users=1] = placeholder[target=primals_3]
%primals_4 : [num_users=1] = placeholder[target=primals_4]
%primals_5 : [num_users=2] = placeholder[target=primals_5]
%primals_6 : [num_users=1] = placeholder[target=primals_6]
%primals_7 : [num_users=1] = placeholder[target=primals_7]
%primals_8 : [num_users=1] = placeholder[target=primals_8]
%primals_9 : [num_users=1] = placeholder[target=primals_9]
Comparison and execution
if "instance" in storage:
if backward:
print(f"-- {len(storage['instance'])} onnx models were creates")
for i, inst in enumerate(storage["instance"]):
print(f" model {i}: {len(inst['inputs'])} runs")
# deal with backward
onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
onnx_models = list(sorted([m for m in models if m.endswith(".onnx") and "_1" in m]))
assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
model_onnxrt = os.path.join(folder, onnx_models[1])
model_debug = os.path.join(folder, onnx_models[0])
else:
onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
if len(onnx_models) == 2:
model_onnxrt = os.path.join(folder, onnx_models[1])
model_debug = os.path.join(folder, onnx_models[0])
else:
model_debug = os.path.join(folder, onnx_models[0])
# the following error may appear:
# Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
print(f"One model is missing, onnx_models={onnx_models}")
model_onnxrt = model_debug
print(f"model_onnxrt={model_onnxrt}")
print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx
The inputs of both models
if "instance" in storage:
print("onnxrt:", inputs_from_onnx_model(model_onnxrt))
print("debug:", inputs_from_onnx_model(model_debug))
onnxrt: [('INPUT', 'primals_2', 1, (4000, 512)), ('INPUT', 'primals_1', 7, (2, 1024)), ('INPUT', 'primals_3', 1, (2, 1024)), ('INPUT', 'primals_4', 1, (32,)), ('INPUT', 'primals_6', 1, (512, 512)), ('INPUT', 'primals_7', 1, (512, 512)), ('INPUT', 'primals_8', 1, (512, 512)), ('INPUT', 'primals_9', 1, (512, 512)), ('INPUT', 'primals_11', 1, (2000, 512)), ('INPUT', 'primals_12', 1, (2000, 512)), ('INPUT', 'primals_13', 1, (512, 2000)), ('INPUT', 'primals_5', 1, (512,)), ('INPUT', 'primals_10', 1, (512,)), ('INPUT', 'primals_14', 1, (512,))]
debug: [('INPUT', 'input0', 7, (2, 1024)), ('INPUT', 'input1', 1, (4000, 512)), ('INPUT', 'input2', 1, (2, 1024)), ('INPUT', 'input3', 1, (32,)), ('INPUT', 'input4', 1, (512,)), ('INPUT', 'input5', 1, (512, 512)), ('INPUT', 'input6', 1, (512, 512)), ('INPUT', 'input7', 1, (512, 512)), ('INPUT', 'input8', 1, (512, 512)), ('INPUT', 'input9', 1, (512,)), ('INPUT', 'input10', 1, (2000, 512)), ('INPUT', 'input11', 1, (2000, 512)), ('INPUT', 'input12', 1, (512, 2000)), ('INPUT', 'input13', 1, (512,))]
The inputs are not the same. The first model has more inputs; for model_debug, some of them were moved into the initializer list.
if "instance" in storage:
print("debug:", inputs_from_onnx_model(model_debug, init=True))
debug: [('INPUT', 'input0', 7, (2, 1024)), ('INPUT', 'input1', 1, (4000, 512)), ('INPUT', 'input2', 1, (2, 1024)), ('INPUT', 'input3', 1, (32,)), ('INPUT', 'input4', 1, (512,)), ('INPUT', 'input5', 1, (512, 512)), ('INPUT', 'input6', 1, (512, 512)), ('INPUT', 'input7', 1, (512, 512)), ('INPUT', 'input8', 1, (512, 512)), ('INPUT', 'input9', 1, (512,)), ('INPUT', 'input10', 1, (2000, 512)), ('INPUT', 'input11', 1, (2000, 512)), ('INPUT', 'input12', 1, (512, 2000)), ('INPUT', 'input13', 1, (512,)), ('INIT', 'init7_s_0', 7, ()), ('INIT', 'init7_s_1024', 7, ()), ('INIT', 'init7_s_1', 7, ()), ('INIT', 'init7_s2_1024_1024', 7, (2,)), ('INIT', 'init7_s2_-1_1', 7, (2,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init7_s4_2_1_1024_1024', 7, (4,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init1_s1_', 1, (1,)), ('INIT', 'init1_s_2', 1, ()), ('INIT', 'init1_s1_2', 1, (1,)), ('INIT', 'init1_s_3', 1, ()), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_-1_64', 7, (4,)), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_4', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s3_2_1024_2000', 7, (3,)), ('INIT', 'init7_s2_2048_2000', 7, (2,)), ('INIT', 'init7_s2_0_1', 7, (2,)), ('INIT', 'init7_s2_1_2', 7, (2,)), ('INIT', 'init7_s2_0_2', 7, (2,)), ('INIT', 'init7_s2_32_32', 7, (2,))]
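For reference, the initializers of an onnx model can also be listed directly with the onnx API; a minimal sketch using only standard onnx calls:

import onnx

# Print name, element type (TensorProto enum value) and shape of every
# initializer stored inside the debug model.
proto = onnx.load(model_debug)
for init in proto.graph.initializer:
    print("INIT", init.name, init.data_type, tuple(init.dims))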
Optimization and Verification
Let’s try the models with a python backend (reference implementation). First, onnxscript uses many functions. The reference evaluator expects every function to be defined before it is used, so the order of functions in the model matters. No recursion is allowed by this runtime. We need to reorder the functions, as function Rank is usually placed at the end of the model.
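reorder_functions_in_proto takes care of this below. Purely as an illustration, a hypothetical reordering could topologically sort the functions so that every function appears after the functions it calls; a minimal sketch, assuming functions are identified by (domain, name) and that there is no recursion:

import onnx

def reorder_functions(model: onnx.ModelProto) -> None:
    # Hypothetical sketch: sort model.functions so that every FunctionProto
    # appears after the functions it calls.
    known = {(f.domain, f.name): f for f in model.functions}
    ordered, seen = [], set()

    def visit(f):
        if (f.domain, f.name) in seen:
            return
        seen.add((f.domain, f.name))
        for node in f.node:
            callee = known.get((node.domain, node.op_type))
            if callee is not None:
                visit(callee)  # emit dependencies first
        ordered.append(f)

    for f in list(model.functions):
        visit(f)
    del model.functions[:]
    model.functions.extend(ordered)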
if "instance" in storage:
reorder_functions_in_proto(model_onnxrt)
Let’s load the models and optimize them.
if "instance" in storage:
debug = onnx.load(model_debug)
try:
onnxrt = optimize_model_proto_oxs(onnx.load(model_onnxrt))
except ImportError as e:
print("missing library", e)
onnxrt = debug
Applied 15 of general pattern rewrite rules.
Let’s apply onnxruntime optimization
if "instance" in storage and ortopt:
providers = (
[("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
if use_cuda
else ["CPUExecutionProvider"]
)
with open(model_onnxrt.replace(".onnx", ".before.opt.onnx"), "wb") as f:
f.write(onnxrt.SerializeToString())
print(f"run onnxruntime optimization on {model_onnxrt}")
optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
ort_optimize(onnxrt, output=optimized, providers=providers)
onnxrt = onnx.load(optimized)
print(f"run onnxruntime optimization on {model_debug}")
optimized = model_debug.replace(".onnx", ".opt.onnx")
ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx
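ort_optimize comes from experimental_experiment.convert.convert_helper. A rough equivalent restricted to the public onnxruntime API (a sketch under that assumption, not the actual implementation) saves the optimized graph through SessionOptions.optimized_model_filepath:

import onnx
import onnxruntime

def ort_offline_optimize(model_proto, output, providers):
    # Creating an InferenceSession with optimized_model_filepath set makes
    # onnxruntime write the optimized graph to disk; reload it with onnx.load.
    so = onnxruntime.SessionOptions()
    so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.optimized_model_filepath = output
    onnxruntime.InferenceSession(model_proto.SerializeToString(), so, providers=providers)
    return onnx.load(output)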
For what follows, we need to build two sets of matching inputs.
if "instance" in storage:
print("build_matching_inputs")
feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
print("done")
build_matching_inputs
done
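feeds is the dictionary of numpy inputs for model_debug; it is built earlier in the full example from the tensors captured by the debug backend and is not shown in this excerpt. build_matching_inputs reuses those values under the input names expected by the onnxrt model. A hypothetical sketch of such a matching, pairing inputs by element type and shape:

import onnx

def match_feeds(model_src, feeds, model_dst):
    # Hypothetical helper: reuse the values of `feeds` (name -> numpy array,
    # inputs of model_src) under the input names of model_dst, pairing
    # inputs that share the same element type and shape.
    src = onnx.load(model_src).graph.input
    dst = onnx.load(model_dst).graph.input

    def signature(value_info):
        t = value_info.type.tensor_type
        return (t.elem_type, tuple(d.dim_value for d in t.shape.dim))

    pool = {}
    for vi in src:
        pool.setdefault(signature(vi), []).append(feeds[vi.name])
    return {vi.name: pool[signature(vi)].pop(0) for vi in dst}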
We check that both models run.
if "instance" in storage:
out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
assert out_onnxrt
assert out_debug
# assert_all_close(out_onnxrt, out_debug)
Side by side
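The listing below is produced by compare_onnx_execution, imported above from onnx_array_api.reference. The exact call is not part of this excerpt; a minimal sketch, assuming its documented signature and that both feed dictionaries are passed as inputs:

if "instance" in storage:
    # Assumption: compare_onnx_execution returns the results of both models,
    # their alignment, and a distance object able to render the report below.
    res1, res2, align, dc = compare_onnx_execution(
        onnxrt, debug, inputs=(feedsrt, feeds), verbose=1
    )
    print(dc.to_str(res1, res2, align))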
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 177 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 177 results (first model)
[compare_onnx_execution] got 171 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 212 pairs
[compare_onnx_execution] done
001 ~ | INITIA float32 ?AAA _val_22 | INITIA float32 1:1 AAAA _reshape_init1_s_0
002 - | INITIA float32 4:1024x1x2x1024 ???? _val_415 |
003 - | INITIA int64 2:1x1 AAAA _val_436 |
004 ~ | INITIA int64 1:4 CIKK _val_503 | INITIA float32 1:1 AAAA _reshape_init1_s_303
005 ~ | INITIA int64 1:3 QMKA _val_464 | INITIA float32 3:1x1x1024 KAQG _to_copy
006 - | INITIA int64 BAAA dim_0__4 |
007 ~ | INITIA int64 1:3 CKSA _val_585 | INITIA int64 1:3 QKKA init7_s3_16_1024_1024
008 ~ | INITIA float32 4:1x2x1024x1024 ???? _val_438 | INITIA float32 4:2x1x1024x1024 ???? expand_1
009 ~ | INITIA int64 1:3 CKZA _val_544 | INITIA int64 1:1 BAAA init7_s1_1
010 - | INITIA int64 2:1024x1 KAQG _val_413 |
011 ~ | INITIA float32 AAAA scalar_tensor_default | INITIA int64 1:2 ACAA init7_s2_0_2
012 ~ | INITIA int64 1:2 UYAA _val_581 | INITIA int64 1:2 GGAA init7_s2_32_32
013 ~ | INITIA int64 AAAA dim_0__9 | INITIA float32 1:1 ?AAA init1_s1_
014 ~ | INITIA int64 1:1 ZAAA _val_592 | INITIA int64 1:2 BCAA init7_s2_1_2
015 - | INITIA float32 AAAA _val_522 |
016 ~ | INITIA int64 1:2 USAA _val_188 | INITIA float32 1:1 CAAA init1_s1_2
017 - | INITIA float32 3:1x1x1024 KAQG view_2 |
018 - | INITIA float32 4:2x1x1024x1024 ???? expand_1 |
019 ~ | INITIA int64 1:3 QKKA _val_533 | INITIA int64 1:2 UYAA init7_s2_2048_2000
020 ~ | INITIA int64 1:2 GGAA splits | INITIA int64 1:2 USAA init7_s2_2048_512
021 - | INITIA int64 CAAA _val_588 |
022 - | INITIA float32 AAAA _val_560 |
023 ~ | INITIA int64 1:4 CKZM _val_233 | INITIA int64 1:3 CKSA init7_s3_2_1024_512
024 ~ | INITIA int64 1:3 CKYA _val_572 | INITIA int64 1:4 CKZM init7_s4_2_1024_-1_64
025 ~ | INITIA int64 1:2 GGAA splits_token_16 | INITIA int64 1:3 QKMA init7_s3_16_1024_64
026 ~ | INITIA int64 1:4 CIKM _val_539 | INITIA int64 1:3 QMKA init7_s3_16_64_1024
027 ~ | INITIA int64 1:3 QKMA _val_377 | INITIA int64 1:3 CKYA init7_s3_2_1024_2000
028 + | | INPUT int64 2:2x1024 ZHFJ input0
029 = | INPUT float32 2:4000x512 AXYM primals_2 | INPUT float32 2:4000x512 AXYM input1
030 - | INPUT int64 2:2x1024 ZHFJ primals_1 |
031 = | INPUT float32 2:2x1024 BACA primals_3 | INPUT float32 2:2x1024 BACA input2
032 = | INPUT float32 1:32 DAAA primals_4 | INPUT float32 1:32 DAAA input3
033 + | | INPUT float32 1:512 YYYY input4
034 = | INPUT float32 2:512x512 EHWU primals_6 | INPUT float32 2:512x512 EHWU input5
035 = | INPUT float32 2:512x512 WEYY primals_7 | INPUT float32 2:512x512 WEYY input6
036 = | INPUT float32 2:512x512 CBSF primals_8 | INPUT float32 2:512x512 CBSF input7
037 = | INPUT float32 2:512x512 JXFC primals_9 | INPUT float32 2:512x512 JXFC input8
038 + | | INPUT float32 1:512 YYYY input9
039 = | INPUT float32 2:2000x512 KOCG primals_11 | INPUT float32 2:2000x512 KOCG input10
040 = | INPUT float32 2:2000x512 DKHP primals_12 | INPUT float32 2:2000x512 DKHP input11
041 = | INPUT float32 2:512x2000 STEQ primals_13 | INPUT float32 2:512x2000 STEQ input12
042 = | INPUT float32 1:512 YYYY primals_5 | INPUT float32 1:512 YYYY input13
043 ~ | INPUT float32 1:512 YYYY primals_10 | RESULT float32 1:512 YYYY Identity output_4
044 ~ | INPUT float32 1:512 YYYY primals_14 | RESULT float32 1:512 YYYY Identity output_3
045 - | RESULT float32 2:512x512 EHWU Identity t_33 |
046 - | RESULT float32 2:512x512 WEYY Identity t_29 |
047 - | RESULT float32 2:512x512 CBSF Identity t_25 |
048 - | RESULT float32 2:512x512 JXFC Identity t_21 |
049 - | RESULT float32 2:2000x512 KOCG Identity t_17 |
050 - | RESULT float32 2:2000x512 DKHP Identity t_13 |
051 - | RESULT float32 2:512x2000 STEQ Identity t_9 |
052 ~ | RESULT float32 2:1x32 DAAA Unsqueeze unsqueeze_7 | RESULT float32 1:512 YYYY Identity output_2
053 + | | RESULT int64 2:2x1024 ZHFJ Identity output_1
054 = | RESULT float32 3:1x32x1 DAAA Unsqueeze unsqueeze_8 | RESULT float32 3:1x32x1 DAAA Unsqueeze unsqueeze_8
055 = | RESULT float32 3:1x32x1024 EFXM MatMul view_3 | RESULT float32 3:1x32x1024 EFXM MatMul bmm
056 = | RESULT float32 3:1x64x1024 JKJK Concat Concat_392 | RESULT float32 3:1x64x1024 JKJK Concat cat_token_5
057 ~ | RESULT float32 3:1x1024x64 VFPY Transpose cat | RESULT float32 3:1x64x1024 RMRM Sin sin_token_7
058 ~ | RESULT float32 3:1x1024x64 GSEC Sin sin | RESULT float32 4:1x1x64x1024 RMRM Unsqueeze unsqueeze10
059 - | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze unsqueeze_11 |
060 = | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_5_out0 | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_10_out0
061 = | RESULT float32 3:2x1024x512 BFOY Gather embedding | RESULT float32 3:2x1024x512 BFOY Gather output_5
062 = | RESULT float32 3:2x1024x512 BAAA Pow pow_1 | RESULT float32 3:2x1024x512 BAAA Pow pow_1
063 = | RESULT float32 3:2x1024x1 AAAA ReduceMean mean | RESULT float32 3:2x1024x1 AAAA ReduceMean mean
064 = | RESULT float32 3:2x1024x1 AAAA Add add_1 | RESULT float32 3:2x1024x1 AAAA Add add_1
065 = | RESULT float32 3:2x1024x1 KKKK Sqrt _val_139 | RESULT float32 3:2x1024x1 KKKK Sqrt _onx_sqrt_add_10
066 = | RESULT float32 3:2x1024x1 UTBY Reciprocal rsqrt | RESULT float32 3:2x1024x1 UTBY Reciprocal output_6
067 = | RESULT float32 3:2x1024x512 NBBI Mul mul_3 | RESULT float32 3:2x1024x512 NBBI Mul output_7
068 = | RESULT float32 3:2x1024x512 NBBI Mul mul_4 | RESULT float32 3:2x1024x512 NBBI Mul mul_4
069 = | RESULT float32 2:2048x512 NBBI Reshape view_4 | RESULT float32 2:2048x512 NBBI Reshape output_9
070 ~ | RESULT float32 2:2048x512 MTOA FusedMatMul mm_1 | RESULT float32 2:2048x512 MTOA Gemm mm_1
071 - | RESULT float32 3:2x1024x512 MTOA Reshape _unsafe_view_1 |
072 = | RESULT float32 4:2x1024x8x64 MTOA Reshape view_7 | RESULT float32 4:2x1024x8x64 MTOA Reshape view_7
073 = | RESULT float32 4:2x1024x8x32 IDGJ Split Slice_460 | RESULT float32 4:2x1024x8x32 IDGJ Split SlicesSplitPattern--slice_Tensor
074 = | RESULT float32 4:2x1024x8x32 FPIR Split Slice_477 | RESULT float32 4:2x1024x8x32 FPIR Split SlicesSplitPattern--slice_Tensor
075 = | RESULT float32 4:2x1024x8x32 VLSJ Neg Neg_500 | RESULT float32 4:2x1024x8x32 VLSJ Neg neg2
076 = | RESULT float32 4:2x1024x8x64 EOXS Concat Concat_508 | RESULT float32 4:2x1024x8x64 EOXS Concat cat3
077 = | RESULT float32 4:2x1024x8x64 TOQN Mul Mul_521 | RESULT float32 4:2x1024x8x64 TOQN Mul mul_Tensor15
078 + | | RESULT float32 3:1x64x1024 NHNH Cos cos_token_13
079 ~ | RESULT float32 3:1x1024x64 CJYF Cos cos | RESULT float32 4:1x1x64x1024 NHNH Unsqueeze unsqueeze9
080 - | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze unsqueeze_10 |
081 = | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_7_out0 | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_16_out0
082 = | RESULT float32 4:2x1024x8x64 VBXL Mul Mul_519 | RESULT float32 4:2x1024x8x64 VBXL Mul mul_Tensor14
083 = | RESULT float32 4:2x1024x8x64 OPNZ Add Add_526 | RESULT float32 4:2x1024x8x64 OPNZ Add add_Tensor4
084 = | RESULT float32 4:2x8x64x1024 WHMZ Transpose transpose_4 | RESULT float32 4:2x8x64x1024 WHMZ Transpose transpose_4
085 + | | RESULT float32 4:1x1x1024x64 GSEC Transpose output_15
086 - | RESULT float32 3:16x64x1024 WHMZ Reshape _unsafe_view_4 |
087 ~ | RESULT float32 2:2048x512 YZMV FusedMatMul mm | RESULT float32 2:2048x512 YZMV Gemm mm
088 - | RESULT float32 3:2x1024x512 YZMV Reshape _unsafe_view |
089 = | RESULT float32 4:2x1024x8x64 YZMV Reshape view_5 | RESULT float32 4:2x1024x8x64 YZMV Reshape view_5
090 = | RESULT float32 4:2x8x1024x64 OIYJ Transpose transpose_1 | RESULT float32 4:2x8x1024x64 OIYJ Transpose transpose_1
091 = | RESULT float32 4:2x8x1024x32 CQTL Split slice_24 | RESULT float32 4:2x8x1024x32 CQTL Split slice_24
092 = | RESULT float32 4:2x8x1024x32 NTGY Split slice_25 | RESULT float32 4:2x8x1024x32 NTGY Split slice_25
093 = | RESULT float32 4:2x8x1024x32 NHUC Neg neg | RESULT float32 4:2x8x1024x32 NHUC Neg neg
094 = | RESULT float32 4:2x8x1024x64 PXNO Concat cat_1 | RESULT float32 4:2x8x1024x64 PXNO Concat cat_1
095 = | RESULT float32 4:2x8x1024x64 SQNQ Mul mul_6 | RESULT float32 4:2x8x1024x64 SQNQ Mul mul_6
096 + | | RESULT float32 4:1x1x1024x64 CJYF Transpose output_14
097 = | RESULT float32 4:2x8x1024x64 SVSM Mul mul_5 | RESULT float32 4:2x8x1024x64 SVSM Mul mul_5
098 = | RESULT float32 4:2x8x1024x64 KKGC Add add_2 | RESULT float32 4:2x8x1024x64 KKGC Add add_2
099 - | RESULT float32 3:16x1024x64 KKGC Reshape _unsafe_view_3 |
100 - | RESULT float32 3:16x1024x1024 UMRA MatMul bmm_1 |
101 - | RESULT float32 4:2x8x1024x1024 UMRA Reshape view_10 |
102 ~ | RESULT float32 4:2x8x1024x1024 SMYR Mul mul_9 | RESULT float32 4:2x8x1024x1024 SMYR FusedMatMul _onx_mul_view_100
103 - | RESULT float32 3:2x1x1024 BACA Unsqueeze unsqueeze_5 |
104 = | RESULT float32 4:2x1x1x1024 BACA Unsqueeze unsqueeze_6 | RESULT float32 4:2x1x1x1024 BACA Unsqueeze unsqueeze_6
105 = | RESULT float32 4:2x1x1024x1024 ???? Add add | RESULT float32 4:2x1x1024x1024 ???? Add add
106 = | RESULT bool 4:2x1x1024x1024 KWTE Equal eq | RESULT bool 4:2x1x1024x1024 KWTE Equal eq
107 = | RESULT float32 4:2x1x1024x1024 ???? Where masked_fill | RESULT float32 4:2x1x1024x1024 ???? Where masked_fill
108 - | RESULT float32 4:1024x1x2x1024 ???? Transpose _val_414 |
109 - | RESULT float32 4:1024x1x2x1024 ???? ScatterND _val_416 |
110 - | RESULT float32 4:1x2x1024x1024 ???? Transpose _val_437 |
111 - | RESULT float32 4:1x2x1024x1024 ???? ScatterND _val_439 |
112 - | RESULT float32 4:2x1x1024x1024 ???? Transpose slice_scatter_1 |
113 = | RESULT float32 4:2x8x1024x1024 ???? Add add_4 | RESULT float32 4:2x8x1024x1024 ???? Add add_4
114 = | RESULT float32 4:2x8x1024x1024 OOON Softmax _softmax | RESULT float32 4:2x8x1024x1024 OOON Softmax output_18
115 - | RESULT float32 3:16x1024x1024 OOON Reshape view_11 |
116 ~ | RESULT float32 2:2048x512 PTHN FusedMatMul mm_2 | RESULT float32 2:2048x512 PTHN Gemm mm_2
117 - | RESULT float32 3:2x1024x512 PTHN Reshape _unsafe_view_2 |
118 = | RESULT float32 4:2x1024x8x64 PTHN Reshape view_9 | RESULT float32 4:2x1024x8x64 PTHN Reshape view_9
119 = | RESULT float32 4:2x8x1024x64 AHLI Transpose transpose_3 | RESULT float32 4:2x8x1024x64 AHLI Transpose transpose_3
120 - | RESULT float32 3:16x1024x64 AHLI Reshape _unsafe_view_5 |
121 - | RESULT float32 3:16x1024x64 EFMR MatMul bmm_2 |
122 ~ | RESULT float32 4:2x8x1024x64 EFMR Reshape view_12 | RESULT float32 4:2x8x1024x64 EFMR MatMul view_12
123 = | RESULT float32 4:2x1024x8x64 EEPN Transpose transpose_5 | RESULT float32 4:2x1024x8x64 EEPN Transpose transpose_5
124 - | RESULT float32 3:2x1024x512 EEPN Reshape view_13 |
125 = | RESULT float32 2:2048x512 EEPN Reshape view_14 | RESULT float32 2:2048x512 EEPN Reshape output_22
126 ~ | RESULT float32 2:2048x512 RRNI FusedMatMul mm_3 | RESULT float32 2:2048x512 RRNI Gemm mm_3
127 = | RESULT float32 3:2x1024x512 RRNI Reshape _unsafe_view_6 | RESULT float32 3:2x1024x512 RRNI Reshape _unsafe_view_6
128 = | RESULT float32 3:2x1024x512 SWCG Add add_5 | RESULT float32 3:2x1024x512 SWCG Add output_23
129 = | RESULT float32 3:2x1024x512 DYPC Pow pow_2 | RESULT float32 3:2x1024x512 DYPC Pow pow_2
130 = | RESULT float32 3:2x1024x1 VVLL ReduceMean mean_1 | RESULT float32 3:2x1024x1 VVLL ReduceMean mean_1
131 = | RESULT float32 3:2x1024x1 VVLL Add add_6 | RESULT float32 3:2x1024x1 VVLL Add add_6
132 = | RESULT float32 3:2x1024x1 BBZZ Sqrt _val_562 | RESULT float32 3:2x1024x1 BBZZ Sqrt _onx_sqrt_add_60
133 = | RESULT float32 3:2x1024x1 GGHJ Reciprocal rsqrt_1 | RESULT float32 3:2x1024x1 GGHJ Reciprocal output_24
134 = | RESULT float32 3:2x1024x512 SIYF Mul mul_10 | RESULT float32 3:2x1024x512 SIYF Mul output_25
135 = | RESULT float32 3:2x1024x512 SIYF Mul mul_11 | RESULT float32 3:2x1024x512 SIYF Mul mul_11
136 = | RESULT float32 2:2048x512 SIYF Reshape view_15 | RESULT float32 2:2048x512 SIYF Reshape output_27
137 ~ | RESULT float32 2:2048x2000 VALD FusedMatMul mm_4 | RESULT float32 2:2048x2000 VALD Gemm mm_4
138 = | RESULT float32 3:2x1024x2000 VALD Reshape _unsafe_view_7 | RESULT float32 3:2x1024x2000 VALD Reshape output_28
139 ~ | RESULT float32 3:2x1024x2000 ZDZH QuickGelu silu | RESULT float32 3:2x1024x2000 DEHS Sigmoid _onx_sigmoid__unsafe_view_70
140 ~ | RESULT float32 2:2048x2000 EBAV FusedMatMul mm_5 | RESULT float32 2:2048x2000 DEHS Reshape Reshape2Of3PatternR__onx_sigmoid
141 ~ | RESULT float32 3:2x1024x2000 EBAV Reshape _unsafe_view_8 | RESULT float32 2:2048x2000 ZDZH Mul Reshape2Of3PatternL_output_29
142 ~ | RESULT float32 3:2x1024x2000 EWBL Mul mul_12 | RESULT float32 2:2048x2000 EBAV Gemm mm_5
143 ~ | RESULT float32 2:2048x2000 EWBL Reshape view_17 | RESULT float32 2:2048x2000 EWBL Mul output_34
144 ~ | RESULT float32 2:2048x512 SCST FusedMatMul mm_6 | RESULT float32 2:2048x512 SCST Gemm mm_6
145 = | RESULT float32 3:2x1024x512 SCST Reshape _unsafe_view_9 | RESULT float32 3:2x1024x512 SCST Reshape _unsafe_view_9
146 = | RESULT float32 3:2x1024x512 KXUY Add add_7 | RESULT float32 3:2x1024x512 KXUY Add output_35
147 = | RESULT float32 3:2x1024x512 OOJQ Pow pow_3 | RESULT float32 3:2x1024x512 OOJQ Pow pow_3
148 = | RESULT float32 3:2x1024x1 BBQQ ReduceMean mean_2 | RESULT float32 3:2x1024x1 BBQQ ReduceMean mean_2
149 = | RESULT float32 3:2x1024x1 BBQQ Add add_8 | RESULT float32 3:2x1024x1 BBQQ Add add_8
150 = | RESULT float32 3:2x1024x1 OONN Sqrt _val_596 | RESULT float32 3:2x1024x1 OONN Sqrt _onx_sqrt_add_80
151 = | RESULT float32 3:2x1024x1 HHEH Reciprocal rsqrt_2 | RESULT float32 3:2x1024x1 HHEH Reciprocal output_36
152 = | RESULT float32 3:2x1024x512 IPHM Mul mul_13 | RESULT float32 3:2x1024x512 IPHM Mul output_37
153 = | RESULT float32 3:2x1024x512 IPHM Mul mul_14 | RESULT float32 3:2x1024x512 IPHM Mul output_0
154 + | | RESULT float32 3:2x1024x2000 EBAV Reshape output_32
155 + | | RESULT float32 3:2x1024x2000 ZDZH Reshape output_29
156 + | | RESULT float32 2:2048x512 SIYF Identity output_31
157 ~ | RESULT float32 3:16x1024x1024 OOON Transpose transpose_7 | RESULT float32 3:16x1024x1024 OOON Reshape output_19
158 + | | RESULT float32 3:16x64x1024 WHMZ Reshape output_17
159 + | | RESULT float32 3:16x1024x64 KKGC Reshape output_16
160 - | RESULT float32 4:2x8x1024x1024 OOON Identity detach_13 |
161 ~ | RESULT float32 3:16x1024x64 WHMZ Transpose transpose_10 | RESULT float32 3:16x1024x64 AHLI Reshape output_20
162 ~ | RESULT float32 3:16x64x1024 KKGC Transpose transpose_9 | RESULT float32 2:2048x512 NBBI Identity output_11
163 ~ | RESULT float32 3:16x64x1024 AHLI Transpose transpose_8 | RESULT float32 2:2048x512 NBBI Identity output_13
164 - | OUTPUT float32 3:2x1024x512 BFOY embedding |
165 ~ | OUTPUT float32 2:512x512 EHWU t_33 | RESULT float32 2:512x512 ZVHA Transpose output_8
166 ~ | OUTPUT float32 2:512x512 WEYY t_29 | RESULT float32 2:512x512 BYCS Transpose output_10
167 ~ | OUTPUT float32 2:512x512 CBSF t_25 | RESULT float32 2:512x512 CEYX Transpose output_12
168 ~ | OUTPUT float32 2:512x512 JXFC t_21 | RESULT float32 2:512x512 FDBD Transpose output_21
169 + | | RESULT float32 2:512x2000 CUNX Transpose output_26
170 + | | RESULT float32 2:512x2000 ADXJ Transpose output_30
171 ~ | OUTPUT float32 2:2000x512 KOCG t_17 | RESULT float32 2:2000x512 UYAO Transpose output_33
172 + | | OUTPUT float32 3:2x1024x512 IPHM output_0
173 + | | OUTPUT int64 2:2x1024 ZHFJ output_1
174 + | | OUTPUT float32 1:512 YYYY output_2
175 + | | OUTPUT float32 1:512 YYYY output_3
176 + | | OUTPUT float32 1:512 YYYY output_4
177 + | | OUTPUT float32 3:2x1024x512 BFOY output_5
178 - | OUTPUT float32 2:2000x512 DKHP t_13 |
179 - | OUTPUT float32 2:512x2000 STEQ t_9 |
180 = | OUTPUT float32 3:2x1024x1 UTBY rsqrt | OUTPUT float32 3:2x1024x1 UTBY output_6
181 + | | OUTPUT float32 3:2x1024x512 NBBI output_7
182 + | | OUTPUT float32 2:512x512 ZVHA output_8
183 = | OUTPUT float32 2:2048x512 NBBI view_4 | OUTPUT float32 2:2048x512 NBBI output_9
184 + | | OUTPUT float32 2:512x512 BYCS output_10
185 - | OUTPUT float32 3:1x1024x64 VFPY cat |
186 ~ | OUTPUT float32 3:16x64x1024 AHLI transpose_8 | OUTPUT float32 2:2048x512 NBBI output_11
187 + | | OUTPUT float32 2:512x512 CEYX output_12
188 ~ | OUTPUT float32 3:16x64x1024 KKGC transpose_9 | OUTPUT float32 2:2048x512 NBBI output_13
189 + | | OUTPUT float32 4:1x1x1024x64 CJYF output_14
190 + | | OUTPUT float32 4:1x1x1024x64 GSEC output_15
191 ~ | OUTPUT float32 3:16x1024x64 WHMZ transpose_10 | OUTPUT float32 3:16x1024x64 KKGC output_16
192 + | | OUTPUT float32 3:16x64x1024 WHMZ output_17
193 = | OUTPUT float32 4:2x8x1024x1024 OOON detach_13 | OUTPUT float32 4:2x8x1024x1024 OOON output_18
194 = | OUTPUT float32 3:16x1024x1024 OOON transpose_7 | OUTPUT float32 3:16x1024x1024 OOON output_19
195 + | | OUTPUT float32 3:16x1024x64 AHLI output_20
196 + | | OUTPUT float32 2:512x512 FDBD output_21
197 = | OUTPUT float32 2:2048x512 EEPN view_14 | OUTPUT float32 2:2048x512 EEPN output_22
198 ~ | OUTPUT float32 2:2048x512 RRNI mm_3 | OUTPUT float32 3:2x1024x512 SWCG output_23
199 = | OUTPUT float32 3:2x1024x1 GGHJ rsqrt_1 | OUTPUT float32 3:2x1024x1 GGHJ output_24
200 + | | OUTPUT float32 3:2x1024x512 SIYF output_25
201 + | | OUTPUT float32 2:512x2000 CUNX output_26
202 = | OUTPUT float32 2:2048x512 SIYF view_15 | OUTPUT float32 2:2048x512 SIYF output_27
203 ~ | OUTPUT float32 2:2048x2000 VALD mm_4 | OUTPUT float32 3:2x1024x2000 VALD output_28
204 + | | OUTPUT float32 3:2x1024x2000 ZDZH output_29
205 + | | OUTPUT float32 2:512x2000 ADXJ output_30
206 + | | OUTPUT float32 2:2048x512 SIYF output_31
207 ~ | OUTPUT float32 2:2048x2000 EBAV mm_5 | OUTPUT float32 3:2x1024x2000 EBAV output_32
208 + | | OUTPUT float32 2:2000x512 UYAO output_33
209 = | OUTPUT float32 2:2048x2000 EWBL view_17 | OUTPUT float32 2:2048x2000 EWBL output_34
210 = | OUTPUT float32 3:2x1024x512 KXUY add_7 | OUTPUT float32 3:2x1024x512 KXUY output_35
211 = | OUTPUT float32 3:2x1024x1 HHEH rsqrt_2 | OUTPUT float32 3:2x1024x1 HHEH output_36
212 = | OUTPUT float32 3:2x1024x512 IPHM mul_14 | OUTPUT float32 3:2x1024x512 IPHM output_37
Total running time of the script: (0 minutes 43.282 seconds)