301: Compares LLAMA exporters for onnxrt backend¶
The script compares the models exported from pytorch with the onnxrt backend and with a debugging backend. It does a side-by-side comparison of the execution of both models.
To run the script:
python _doc/examples/plot_llama_diff_export.py --help
The following example compares the forward step in mixed precision on cuda and produces all the intermediate onnx graphs. You may use --backward=1 to compare the backward graphs as well.
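For instance, comparing the backward graphs in mixed precision on cuda could look like this (a hypothetical invocation built from the arguments exposed below):
python _doc/examples/plot_llama_diff_export.py --backward=1 --cuda=1 --mixed=1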
Some helpers¶
from experimental_experiment.args import get_parsed_args
script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("model", "one value among model, ..."),
    ortopt=(1, "run onnxruntime optimization"),
    backward=(0, "does one operator for backward"),
    cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,cuda,mixed,opset",
)
import copy
import os
import warnings
import logging
try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
    ort_optimize,
    optimize_model_proto_oxs,
)
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.torch_models.dump_helper import (
    assert_all_close,
    dump_onnx,
    reorder_functions_in_proto,
    inputs_from_onnx_model,
    build_matching_inputs,
    results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
    train_loop,
    make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
    onnx_debug_backend,
    get_decomposition_table,
)
has_cuda = has_cuda and torch.cuda.device_count() > 0
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"
The exporting functions¶
print(f"part={script_args.part}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
backward = script_args.backward in (1, "1")
print(f"backward={backward}")
use_cuda = script_args.cuda in (1, "1")
print(f"cuda={use_cuda}")
use_mixed = script_args.mixed in (1, "1")
print(f"mixed={use_mixed}")
opset = int(script_args.opset)
print(f"opset={opset}")
part=model
ortopt=True
backward=False
cuda=False
mixed=False
opset=18
Model and data¶
if unit_test_going():
    kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
    kwargs = dict(
        input_dims=[(2, 1024)] * 2,
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

if script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")

if use_cuda:
    model = model.to("cuda")
    inputs = [[i.to("cuda") for i in inp] for inp in inputs]
print(f"simple run with {len(inputs)} inputs")
if backward:
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = train_loop(copy.deepcopy(model), *inputs[0])
            torch.cuda.synchronize()
    else:
        expected = train_loop(copy.deepcopy(model), *inputs[0])
    print(
        f"-- eager mode worked, {len(expected)} gradients, first one is "
        f"{expected[0].shape}, {expected[0].dtype}"
    )
else:
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = model(*inputs[0])
            torch.cuda.synchronize()
    else:
        expected = model(*inputs[0])
    print(results_to_string(expected))
simple run with 2 inputs
1 results
torch.float32 (2, 1024, 512) [sum=2.15e+04]
Exporting¶
if hasattr(torch._dynamo.variables.misc, "LoggingLoggerVariable"):
    # A tweak to make torch.export.export work.
    torch._dynamo.variables.misc.LoggingLoggerVariable.call_method = lambda *_, **__: None
folder = "dump_models"
storage = {}
if backward:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)

    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
    )
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = train_loop(optimized_mod, *inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = train_loop(optimized_mod, *inputs[0])
    assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
    print(
        f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
        f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
    )

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **kwargs,
        ),
        decompositions=get_decomposition_table(),
    )

    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)

    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            got = train_loop(onnx_mod, *inputs[0])
            torch.cuda.synchronize()
    else:
        got = train_loop(onnx_mod, *inputs[0])
    assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
    print(
        f"-- debug backend worked, {len(got)} gradients, first one is "
        f"{got[0].shape}, {got[0].dtype}"
    )
else:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = optimized_mod(*inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = optimized_mod(*inputs[0])
    assert_all_close(expected, expected_onnxrt, atol=1e-2)

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **kwargs,
        )
    )
    onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            got = onnx_mod(*inputs[0])
    else:
        try:
            got = onnx_mod(*inputs[0])
        except Exception as e:
            print(f"ERROR: {e}")
            got = None

    if got is not None:
        assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:91: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py:394: FutureWarning: 'onnxscript.values.TracedOnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
self.param_schema = self.onnxfunction.param_schemas()
For the forward pass, there are two files per backend: one onnx model and the graph module printed in a text file. For the backward pass, there are two onnx models. This is then multiplied by the number of backends.
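As a quick check, the dumped files can be grouped by backend prefix (a small sketch reusing folder defined above):
from collections import defaultdict

groups = defaultdict(list)
for name in sorted(os.listdir(folder)):
    # 'llama_onnxrt_0.onnx' -> prefix 'llama_onnxrt'
    groups[name.rsplit("_", 1)[0]].append(name)
print(dict(groups))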
models = os.listdir(folder)
print(f"exported models: {models}")
exported models: ['llama_debug_0.onnx', 'llama_onnxrt_0.txt', 'llama_debug_0.txt', 'llama_onnxrt_0.onnx']
Inputs used by the debug backend
-- input0 int64 (2, 1024)
-- input1 float32 (4000, 512)
-- input2 float32 (2, 1024)
-- input3 float32 (32,)
-- input4 float32 (512,)
-- input5 float32 (512, 512)
-- input6 float32 (512, 512)
-- input7 float32 (512, 512)
-- input8 float32 (512, 512)
-- input9 float32 (512,)
-- input10 float32 (2000, 512)
-- input11 float32 (2000, 512)
-- input12 float32 (512, 2000)
-- input13 float32 (512,)
Let’s print the first lines of the graph module.
if "instance" in storage:
graph_module = storage["instance"][0]["graph_module"]
print("\n".join(str(graph_module.graph).split("\n")[:10]))
graph():
%primals_1 : [num_users=2] = placeholder[target=primals_1]
%primals_2 : [num_users=1] = placeholder[target=primals_2]
%primals_3 : [num_users=1] = placeholder[target=primals_3]
%primals_4 : [num_users=1] = placeholder[target=primals_4]
%primals_5 : [num_users=2] = placeholder[target=primals_5]
%primals_6 : [num_users=1] = placeholder[target=primals_6]
%primals_7 : [num_users=1] = placeholder[target=primals_7]
%primals_8 : [num_users=1] = placeholder[target=primals_8]
%primals_9 : [num_users=1] = placeholder[target=primals_9]
Comparison and execution¶
if "instance" in storage:
if backward:
print(f"-- {len(storage['instance'])} onnx models were creates")
for i, inst in enumerate(storage["instance"]):
print(f" model {i}: {len(inst['inputs'])} runs")
# deal with backward
onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
onnx_models = list(sorted([m for m in models if m.endswith(".onnx") and "_1" in m]))
assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
model_onnxrt = os.path.join(folder, onnx_models[1])
model_debug = os.path.join(folder, onnx_models[0])
else:
onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
if len(onnx_models) == 2:
model_onnxrt = os.path.join(folder, onnx_models[1])
model_debug = os.path.join(folder, onnx_models[0])
else:
model_debug = os.path.join(folder, onnx_models[0])
# the following error may appear:
# Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
print(f"One model is missing, onnx_models={onnx_models}")
model_onnxrt = model_debug
print(f"model_onnxrt={model_onnxrt}")
print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx
The inputs of both models
if "instance" in storage:
print("onnxrt:", inputs_from_onnx_model(model_onnxrt))
print("debug:", inputs_from_onnx_model(model_debug))
onnxrt: [('INPUT', 'primals_13', 1, (512, 2000)), ('INPUT', 'primals_12', 1, (2000, 512)), ('INPUT', 'primals_11', 1, (2000, 512)), ('INPUT', 'primals_9', 1, (512, 512)), ('INPUT', 'primals_8', 1, (512, 512)), ('INPUT', 'primals_7', 1, (512, 512)), ('INPUT', 'primals_6', 1, (512, 512)), ('INPUT', 'primals_4', 1, (32,)), ('INPUT', 'primals_3', 1, (2, 1024)), ('INPUT', 'primals_2', 1, (4000, 512)), ('INPUT', 'primals_1', 7, (2, 1024)), ('INPUT', 'primals_5', 1, (512,)), ('INPUT', 'primals_10', 1, (512,)), ('INPUT', 'primals_14', 1, (512,))]
debug: [('INPUT', 'input0', 7, (2, 1024)), ('INPUT', 'input1', 1, (4000, 512)), ('INPUT', 'input2', 1, (2, 1024)), ('INPUT', 'input3', 1, (32,)), ('INPUT', 'input4', 1, (512,)), ('INPUT', 'input5', 1, (512, 512)), ('INPUT', 'input6', 1, (512, 512)), ('INPUT', 'input7', 1, (512, 512)), ('INPUT', 'input8', 1, (512, 512)), ('INPUT', 'input9', 1, (512,)), ('INPUT', 'input10', 1, (2000, 512)), ('INPUT', 'input11', 1, (2000, 512)), ('INPUT', 'input12', 1, (512, 2000)), ('INPUT', 'input13', 1, (512,))]
The inputs are not the same. The first model has more of them: for model_debug, some inputs were moved into the initializer list. In each tuple, the integer is the onnx element type (1 for float32, 7 for int64).
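Under the hood, graph inputs and initializers live in two different lists of the GraphProto. A minimal sketch with the raw onnx API, assuming model_debug points to an existing file:
proto = onnx.load(model_debug)
# true graph inputs
print("inputs:", [i.name for i in proto.graph.input])
# constants stored with the model
print("initializers:", [i.name for i in proto.graph.initializer])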
if "instance" in storage:
print("debug:", inputs_from_onnx_model(model_debug, init=True))
debug: [('INPUT', 'input0', 7, (2, 1024)), ('INPUT', 'input1', 1, (4000, 512)), ('INPUT', 'input2', 1, (2, 1024)), ('INPUT', 'input3', 1, (32,)), ('INPUT', 'input4', 1, (512,)), ('INPUT', 'input5', 1, (512, 512)), ('INPUT', 'input6', 1, (512, 512)), ('INPUT', 'input7', 1, (512, 512)), ('INPUT', 'input8', 1, (512, 512)), ('INPUT', 'input9', 1, (512,)), ('INPUT', 'input10', 1, (2000, 512)), ('INPUT', 'input11', 1, (2000, 512)), ('INPUT', 'input12', 1, (512, 2000)), ('INPUT', 'input13', 1, (512,)), ('INIT', 'init7_s_0', 7, ()), ('INIT', 'init7_s_1024', 7, ()), ('INIT', 'init7_s_1', 7, ()), ('INIT', 'init7_s2_1024_1024', 7, (2,)), ('INIT', 'init7_s2_-1_1', 7, (2,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init7_s4_2_1_1024_1024', 7, (4,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init1_s1_', 1, (1,)), ('INIT', 'init1_s1_2', 1, (1,)), ('INIT', 'init1_s_3', 1, ()), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_-1_64', 7, (4,)), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_4', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s3_2_1024_2000', 7, (3,)), ('INIT', 'init7_s2_2048_2000', 7, (2,)), ('INIT', 'init7_s2_0_1', 7, (2,)), ('INIT', 'init7_s2_1_2', 7, (2,)), ('INIT', 'init7_s2_0_2', 7, (2,)), ('INIT', 'init7_s2_32_32', 7, (2,))]
Optimization and Verification¶
Let’s try the model with a python backend (reference implementation). First step: onnxscript uses many functions. The reference evaluator expects every function to be defined before it is used, so the order of functions in the model matters. No recursion is allowed by this runtime. We need to reorder the functions, as the function Rank is usually placed at the end of the model.
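The helper sorts the local functions so that callees come before callers. A minimal sketch of the idea (the actual reorder_functions_in_proto may proceed differently):
def reorder_functions(model: onnx.ModelProto) -> None:
    # Hypothetical sketch: a post-order walk appends every function after
    # the functions it calls, so callees always come first.
    by_name = {(f.domain, f.name): f for f in model.functions}
    seen, ordered = set(), []

    def visit(key):
        if key in seen or key not in by_name:
            return
        seen.add(key)
        for node in by_name[key].node:
            visit((node.domain, node.op_type))  # visit callees first
        ordered.append(by_name[key])

    for key in list(by_name):
        visit(key)
    del model.functions[:]
    model.functions.extend(ordered)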
if "instance" in storage:
reorder_functions_in_proto(model_onnxrt)
Let’s load the models and optimize them.
if "instance" in storage:
debug = onnx.load(model_debug)
try:
onnxrt = optimize_model_proto_oxs(onnx.load(model_onnxrt))
except ImportError as e:
print("missing library", e)
onnxrt = debug
Applied 21 of general pattern rewrite rules.
Applied 2 of general pattern rewrite rules.
Let’s apply onnxruntime optimization.
if "instance" in storage and ortopt:
providers = (
[("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
if use_cuda
else ["CPUExecutionProvider"]
)
with open(model_onnxrt.replace(".onnx", ".before.opt.onnx"), "wb") as f:
f.write(onnxrt.SerializeToString())
print(f"run onnxruntime optimization on {model_onnxrt}")
optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
ort_optimize(onnxrt, output=optimized, providers=providers)
onnxrt = onnx.load(optimized)
print(f"run onnxruntime optimization on {model_debug}")
optimized = model_debug.replace(".onnx", ".opt.onnx")
ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx
For what follows, we need to build two lists of matching inputs.
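build_matching_inputs pairs the inputs of one model with those of the other. A rough sketch of the idea, assuming descriptions in the (kind, name, itype, shape) format returned by inputs_from_onnx_model (the real helper is more careful):
def match_inputs(src_feeds, src_desc, dst_desc):
    # Hypothetical sketch: every input of the destination model is paired
    # with the first unused source input of the same element type and shape.
    unused = [(name, itype, shape) for _, name, itype, shape in src_desc]
    dst_feeds = {}
    for _, name, itype, shape in dst_desc:
        for i, (src_name, src_itype, src_shape) in enumerate(unused):
            if (src_itype, src_shape) == (itype, shape):
                dst_feeds[name] = src_feeds[src_name]
                del unused[i]
                break
        else:
            raise RuntimeError(f"no match for input {name!r}")
    return dst_feeds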
if "instance" in storage:
print("build_matching_inputs")
feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
print("done")
build_matching_inputs
done
We check that both models run.
if "instance" in storage:
out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
assert out_onnxrt
assert out_debug
# assert_all_close(out_onnxrt, out_debug)
Side by side
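The report below is produced by compare_onnx_execution. The call itself is not shown above; a minimal sketch, assuming the API from onnx_array_api.reference:
if "instance" in storage:
    res1, res2, align, dc = compare_onnx_execution(onnxrt, debug, verbose=1)
    text = dc.to_str(res1, res2, align)
    print(text)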
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 156 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 156 results (first model)
[compare_onnx_execution] got 171 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 201 pairs
[compare_onnx_execution] done
001 ~ | INITIA float32 4:1024x1x2x1024 ???? _val_425 | INITIA int64 1:2 GGAA init7_s2_32_32
002 - | INITIA float32 3:1x1x1024 KAQG _to_copy |
003 - | INITIA int64 2:1024x1 KAQG _val_423 |
004 ~ | INITIA int64 1:2 UYAA _val_600 | INITIA int64 1:2 BCAA init7_s2_1_2
005 ~ | INITIA float32 AAAA scalar_tensor_default | INITIA float32 3:1x1x1024 KAQG _to_copy
006 ~ | INITIA int64 2:1x1 AAAA _val_448 | INITIA int64 1:2 ACAA init7_s2_0_2
007 = | INITIA float32 4:2x1x1024x1024 ???? expand_1 | INITIA float32 4:2x1x1024x1024 ???? expand_1
008 - | INITIA int64 1:3 QKMA _val_476 |
009 ~ | INITIA int64 1:2 ACAA val_2 | INITIA int64 1:1 BAAA init7_s1_1
010 ~ | INITIA int64 1:2 BCAA val_0 | INITIA float32 1:1 AAAA _reshape_init1_s_303
011 - | INITIA float32 AAAA _val_539 |
012 - | INITIA float32 AAAA _val_613 |
013 ~ | INITIA int64 1:3 QMKA _val_498 | INITIA int64 1:3 CKYA init7_s3_2_1024_2000
014 ~ | INITIA int64 1:3 CKSA _val_604 | INITIA float32 1:1 ?AAA init1_s1_
015 ~ | INITIA int64 1:3 QKKA _val_552 | INITIA float32 1:1 CAAA init1_s1_2
016 ~ | INITIA int64 1:2 USAA _val_215 | INITIA int64 1:2 UYAA init7_s2_2048_2000
017 - | INITIA float32 4:1x2x1024x1024 ???? _val_450 |
018 ~ | INITIA int64 1:4 CIKK _val_537 | INITIA int64 1:2 USAA init7_s2_2048_512
019 ~ | INITIA int64 1:3 CKYA _val_594 | INITIA int64 1:3 CKSA init7_s3_2_1024_512
020 - | INITIA float32 ?AAA _val_400 |
021 - | INITIA int64 CAAA _val_607 |
022 ~ | INITIA int64 1:4 CKIM Reshape_380_new_shape | INITIA int64 1:4 CKZM init7_s4_2_1024_-1_64
023 ~ | INITIA int64 1:1 ZAAA _val_611 | INITIA int64 1:3 QKMA init7_s3_16_1024_64
024 ~ | INITIA int64 1:4 CIKM _val_558 | INITIA int64 1:3 QMKA init7_s3_16_64_1024
025 ~ | INITIA int64 1:2 GGAA splits | INITIA float32 1:1 AAAA _reshape_init1_s_0
026 ~ | INITIA int64 1:2 GGAA splits_token_10 | INITIA int64 1:3 QKKA init7_s3_16_1024_1024
027 + | | INPUT int64 2:2x1024 IQUS input0
028 + | | INPUT float32 2:4000x512 NYUT input1
029 + | | INPUT float32 2:2x1024 BACA input2
030 ~ | INITIA int64 1:1 BAAA _val_320 | INPUT float32 1:32 DAAA input3
031 + | | INPUT float32 1:512 YYYY input4
032 - | INPUT float32 2:512x2000 MIVH primals_13 |
033 - | INPUT float32 2:2000x512 FHMB primals_12 |
034 - | INPUT float32 2:2000x512 QJXU primals_11 |
035 = | INPUT float32 2:512x512 YHZA primals_9 | INPUT float32 2:512x512 YHZA input5
036 = | INPUT float32 2:512x512 ABYA primals_8 | INPUT float32 2:512x512 ABYA input6
037 = | INPUT float32 2:512x512 HAAA primals_7 | INPUT float32 2:512x512 HAAA input7
038 = | INPUT float32 2:512x512 GAEX primals_6 | INPUT float32 2:512x512 GAEX input8
039 - | INPUT float32 1:32 DAAA primals_4 |
040 - | INPUT float32 2:2x1024 BACA primals_3 |
041 - | INPUT float32 2:4000x512 NYUT primals_2 |
042 - | INPUT int64 2:2x1024 IQUS primals_1 |
043 = | INPUT float32 1:512 YYYY primals_5 | INPUT float32 1:512 YYYY input9
044 + | | INPUT float32 2:2000x512 FHMB input10
045 + | | INPUT float32 2:2000x512 QJXU input11
046 + | | INPUT float32 2:512x2000 MIVH input12
047 = | INPUT float32 1:512 YYYY primals_10 | INPUT float32 1:512 YYYY input13
048 + | | RESULT float32 1:512 YYYY Identity output_4
049 + | | RESULT float32 1:512 YYYY Identity output_3
050 ~ | INPUT float32 1:512 YYYY primals_14 | RESULT float32 1:512 YYYY Identity output_2
051 + | | RESULT int64 2:2x1024 IQUS Identity output_1
052 = | RESULT float32 3:1x32x1 DAAA Unsqueeze unsqueeze_8 | RESULT float32 3:1x32x1 DAAA Unsqueeze unsqueeze_8
053 = | RESULT float32 3:1x32x1024 EFXM MatMul bmm | RESULT float32 3:1x32x1024 EFXM MatMul bmm
054 = | RESULT float32 3:1x64x1024 JKJK Concat Concat_408 | RESULT float32 3:1x64x1024 JKJK Concat cat_token_5
055 ~ | RESULT float32 3:1x1024x64 VFPY Transpose cat | RESULT float32 3:1x64x1024 RMRM Sin sin_token_7
056 ~ | RESULT float32 3:1x1024x64 GSEC Sin sin | RESULT float32 4:1x1x64x1024 RMRM Unsqueeze unsqueeze10
057 - | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze unsqueeze_11 |
058 = | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_5_out0 | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_9_out0
059 = | RESULT float32 3:2x1024x512 ANUI Gather embedding | RESULT float32 3:2x1024x512 ANUI Gather output_5
060 = | RESULT float32 3:2x1024x512 BABB Pow pow_1 | RESULT float32 3:2x1024x512 BABB Pow pow_1
061 = | RESULT float32 3:2x1024x1 AAAA ReduceMean mean | RESULT float32 3:2x1024x1 AAAA ReduceMean mean
062 = | RESULT float32 3:2x1024x1 AAAA Add add_1 | RESULT float32 3:2x1024x1 AAAA Add add_1
063 = | RESULT float32 3:2x1024x1 KKKK Sqrt _val_163 | RESULT float32 3:2x1024x1 KKKK Sqrt _onx_sqrt_add_10
064 = | RESULT float32 3:2x1024x1 EATD Reciprocal rsqrt | RESULT float32 3:2x1024x1 EATD Reciprocal output_6
065 = | RESULT float32 3:2x1024x512 RIFF Mul mul_3 | RESULT float32 3:2x1024x512 RIFF Mul output_7
066 = | RESULT float32 3:2x1024x512 RIFF Mul mul_4 | RESULT float32 3:2x1024x512 RIFF Mul mul_4
067 = | RESULT float32 2:2048x512 RIFF Reshape view_4 | RESULT float32 2:2048x512 RIFF Reshape output_9
068 ~ | RESULT float32 2:2048x512 VIXJ FusedMatMul mm_1 | RESULT float32 2:2048x512 FHSC Gemm mm_1
069 ~ | RESULT float32 4:2x1024x8x64 VIXJ Reshape view_7 | RESULT float32 4:2x1024x8x64 FHSC Reshape view_7
070 ~ | RESULT float32 4:2x1024x8x32 JWNM Split Slice_498 | RESULT float32 4:2x1024x8x32 CSPE Split SlicesSplitPattern--slide_Tensor
071 ~ | RESULT float32 4:2x1024x8x32 MLKY Split Slice_515 | RESULT float32 4:2x1024x8x32 CPCZ Split SlicesSplitPattern--slide_Tensor
072 ~ | RESULT float32 4:2x1024x8x32 OPQC Neg Neg_526 | RESULT float32 4:2x1024x8x32 YLYB Neg neg2
073 ~ | RESULT float32 4:2x1024x8x64 XKEP Concat Concat_536 | RESULT float32 4:2x1024x8x64 ADOF Concat cat3
074 ~ | RESULT float32 4:2x1024x8x64 YLBY Mul Mul_546 | RESULT float32 4:2x1024x8x64 PWPL Mul mul_Tensor15
075 + | | RESULT float32 3:1x64x1024 NHNH Cos cos_token_12
076 ~ | RESULT float32 3:1x1024x64 CJYF Cos cos | RESULT float32 4:1x1x64x1024 NHNH Unsqueeze unsqueeze9
077 - | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze unsqueeze_10 |
078 = | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_7_out0 | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_14_out0
079 ~ | RESULT float32 4:2x1024x8x64 RYQN Mul Mul_541 | RESULT float32 4:2x1024x8x64 JFBY Mul mul_Tensor14
080 ~ | RESULT float32 4:2x1024x8x64 QKSL Add Add_550 | RESULT float32 4:2x1024x8x64 XBQJ Add add_Tensor4
081 ~ | RESULT float32 4:2x8x64x1024 JQRL Transpose transpose_4 | RESULT float32 4:2x8x64x1024 SHZB Transpose transpose_4
082 + | | RESULT float32 4:1x1x1024x64 GSEC Transpose output_15
083 - | RESULT float32 3:16x64x1024 JQRL Reshape _unsafe_view_4 |
084 ~ | RESULT float32 2:2048x512 LRWZ FusedMatMul mm | RESULT float32 2:2048x512 SSQU Gemm mm
085 ~ | RESULT float32 4:2x1024x8x64 LRWZ Reshape view_5 | RESULT float32 4:2x1024x8x64 SSQU Reshape view_5
086 ~ | RESULT float32 4:2x8x1024x64 WHUB Transpose transpose_1 | RESULT float32 4:2x8x1024x64 HEPV Transpose transpose_1
087 ~ | RESULT float32 4:2x8x1024x32 VAVE Split slice_24 | RESULT float32 4:2x8x1024x32 QLUT Split slice_24
088 ~ | RESULT float32 4:2x8x1024x32 BHZX Split slice_25 | RESULT float32 4:2x8x1024x32 QUVB Split slice_25
089 ~ | RESULT float32 4:2x8x1024x32 ZTBD Neg neg | RESULT float32 4:2x8x1024x32 KGFZ Neg neg
090 ~ | RESULT float32 4:2x8x1024x64 UTXH Concat cat_1 | RESULT float32 4:2x8x1024x64 ASZT Concat cat_1
091 ~ | RESULT float32 4:2x8x1024x64 XSAV Mul mul_6 | RESULT float32 4:2x8x1024x64 ZXPQ Mul mul_6
092 + | | RESULT float32 4:1x1x1024x64 CJYF Transpose output_14
093 ~ | RESULT float32 4:2x8x1024x64 MZVS Mul mul_5 | RESULT float32 4:2x8x1024x64 ESYL Mul mul_5
094 ~ | RESULT float32 4:2x8x1024x64 ISVO Add add_2 | RESULT float32 4:2x8x1024x64 CPMB Add add_2
095 - | RESULT float32 3:16x1024x64 ISVO Reshape _unsafe_view_3 |
096 - | RESULT float32 3:16x1024x1024 OUZX MatMul bmm_1 |
097 - | RESULT float32 4:2x8x1024x1024 OUZX Reshape view_10 |
098 ~ | RESULT float32 4:2x8x1024x1024 MFRJ Mul mul_9 | RESULT float32 4:2x8x1024x1024 PLSF FusedMatMul _onx_mul_view_100
099 = | RESULT float32 4:2x1x1x1024 BACA Unsqueeze unsqueeze_6 | RESULT float32 4:2x1x1x1024 BACA Unsqueeze unsqueeze_6
100 = | RESULT float32 4:2x1x1024x1024 ???? Add add | RESULT float32 4:2x1x1024x1024 ???? Add add
101 = | RESULT bool 4:2x1x1024x1024 KWTE Equal eq | RESULT bool 4:2x1x1024x1024 KWTE Equal eq
102 = | RESULT float32 4:2x1x1024x1024 ???? Where masked_fill | RESULT float32 4:2x1x1024x1024 ???? Where masked_fill
103 - | RESULT float32 4:1024x1x2x1024 ???? Transpose _val_424 |
104 - | RESULT float32 4:1024x1x2x1024 ???? ScatterND _val_426 |
105 - | RESULT float32 4:1x2x1024x1024 ???? Transpose _val_449 |
106 - | RESULT float32 4:1x2x1024x1024 ???? ScatterND _val_451 |
107 - | RESULT float32 4:2x1x1024x1024 ???? Transpose slice_scatter_1 |
108 = | RESULT float32 4:2x8x1024x1024 ???? Add add_4 | RESULT float32 4:2x8x1024x1024 ???? Add add_4
109 = | RESULT float32 4:2x8x1024x1024 OONO Softmax detach_13 | RESULT float32 4:2x8x1024x1024 OONO Softmax output_18
110 - | RESULT float32 3:16x1024x1024 OONO Reshape view_11 |
111 ~ | RESULT float32 2:2048x512 FHSC FusedMatMul mm_2 | RESULT float32 2:2048x512 VIXJ Gemm mm_2
112 ~ | RESULT float32 4:2x1024x8x64 FHSC Reshape view_9 | RESULT float32 4:2x1024x8x64 VIXJ Reshape view_9
113 ~ | RESULT float32 4:2x8x1024x64 RVOH Transpose transpose_3 | RESULT float32 4:2x8x1024x64 JURP Transpose transpose_3
114 - | RESULT float32 3:16x1024x64 RVOH Reshape _unsafe_view_5 |
115 - | RESULT float32 3:16x1024x64 CJMJ MatMul bmm_2 |
116 ~ | RESULT float32 4:2x8x1024x64 CJMJ Reshape view_12 | RESULT float32 4:2x8x1024x64 NZAW MatMul view_12
117 ~ | RESULT float32 4:2x1024x8x64 SSXY Transpose transpose_5 | RESULT float32 4:2x1024x8x64 GGRE Transpose transpose_5
118 ~ | RESULT float32 2:2048x512 SSXY Reshape view_14 | RESULT float32 2:2048x512 GGRE Reshape output_22
119 ~ | RESULT float32 2:2048x512 UUWU FusedMatMul mm_3 | RESULT float32 2:2048x512 EEEG Gemm mm_3
120 ~ | RESULT float32 3:2x1024x512 UUWU Reshape _unsafe_view_6 | RESULT float32 3:2x1024x512 EEEG Reshape _unsafe_view_6
121 ~ | RESULT float32 3:2x1024x512 UHPD Add add_5 | RESULT float32 3:2x1024x512 ERYP Add output_23
122 ~ | RESULT float32 3:2x1024x512 SVLV Pow pow_2 | RESULT float32 3:2x1024x512 MHOB Pow pow_2
123 ~ | RESULT float32 3:2x1024x1 VVLL ReduceMean mean_1 | RESULT float32 3:2x1024x1 QQLL ReduceMean mean_1
124 ~ | RESULT float32 3:2x1024x1 VVLL Add add_6 | RESULT float32 3:2x1024x1 QQLL Add add_6
125 ~ | RESULT float32 3:2x1024x1 BBYY Sqrt _val_581 | RESULT float32 3:2x1024x1 OOZZ Sqrt _onx_sqrt_add_60
126 ~ | RESULT float32 3:2x1024x1 SSCF Reciprocal rsqrt_1 | RESULT float32 3:2x1024x1 YZWY Reciprocal output_24
127 ~ | RESULT float32 3:2x1024x512 BNKB Mul mul_10 | RESULT float32 3:2x1024x512 VBVQ Mul output_25
128 ~ | RESULT float32 3:2x1024x512 BNKB Mul mul_11 | RESULT float32 3:2x1024x512 VBVQ Mul mul_11
129 ~ | RESULT float32 2:2048x512 BNKB Reshape view_15 | RESULT float32 2:2048x512 VBVQ Reshape output_27
130 ~ | RESULT float32 2:2048x2000 JCFR FusedMatMul mm_4 | RESULT float32 2:2048x2000 JQQP Gemm mm_4
131 ~ | RESULT float32 3:2x1024x2000 JCFR Reshape _unsafe_view_7 | RESULT float32 3:2x1024x2000 JQQP Reshape output_28
132 ~ | RESULT float32 3:2x1024x2000 FUAW QuickGelu silu | RESULT float32 3:2x1024x2000 OCZO Sigmoid _onx_sigmoid__unsafe_view_70
133 ~ | RESULT float32 2:2048x2000 PWLE FusedMatMul mm_5 | RESULT float32 2:2048x2000 OCZO Reshape Reshape2Of3PatternR__onx_sigmoid
134 ~ | RESULT float32 3:2x1024x2000 PWLE Reshape _unsafe_view_8 | RESULT float32 2:2048x2000 QISR Mul Reshape2Of3PatternL_output_29
135 ~ | RESULT float32 3:2x1024x2000 NBUQ Mul mul_12 | RESULT float32 2:2048x2000 QRZF Gemm mm_5
136 ~ | RESULT float32 2:2048x2000 NBUQ Reshape view_17 | RESULT float32 2:2048x2000 KCDG Mul output_34
137 ~ | RESULT float32 2:2048x512 JZPU FusedMatMul mm_6 | RESULT float32 2:2048x512 IYGR Gemm mm_6
138 ~ | RESULT float32 3:2x1024x512 JZPU Reshape _unsafe_view_9 | RESULT float32 3:2x1024x512 IYGR Reshape _unsafe_view_9
139 ~ | RESULT float32 3:2x1024x512 DFFX Add add_7 | RESULT float32 3:2x1024x512 MODG Add output_35
140 ~ | RESULT float32 3:2x1024x512 XXXB Pow pow_3 | RESULT float32 3:2x1024x512 VXRH Pow pow_3
141 ~ | RESULT float32 3:2x1024x1 ZZQQ ReduceMean mean_2 | RESULT float32 3:2x1024x1 WWOO ReduceMean mean_2
142 ~ | RESULT float32 3:2x1024x1 ZZQQ Add add_8 | RESULT float32 3:2x1024x1 WWOO Add add_8
143 ~ | RESULT float32 3:2x1024x1 KKNN Sqrt _val_615 | RESULT float32 3:2x1024x1 EEJJ Sqrt _onx_sqrt_add_80
144 ~ | RESULT float32 3:2x1024x1 LLDF Reciprocal rsqrt_2 | RESULT float32 3:2x1024x1 YXLM Reciprocal output_36
145 ~ | RESULT float32 3:2x1024x512 YKJO Mul mul_13 | RESULT float32 3:2x1024x512 SLWQ Mul output_37
146 ~ | RESULT float32 3:2x1024x512 YKJO Mul mul_14 | RESULT float32 3:2x1024x512 SLWQ Mul output_0
147 + | | RESULT float32 3:2x1024x2000 QRZF Reshape output_32
148 + | | RESULT float32 3:2x1024x2000 QISR Reshape output_29
149 + | | RESULT float32 2:2048x512 VBVQ Identity output_31
150 ~ | RESULT float32 3:16x1024x1024 OONO Transpose transpose_7 | RESULT float32 3:16x1024x1024 OONO Reshape output_19
151 + | | RESULT float32 3:16x64x1024 SHZB Reshape output_17
152 + | | RESULT float32 3:16x1024x64 CPMB Reshape output_16
153 ~ | RESULT float32 3:16x1024x64 JQRL Transpose transpose_10 | RESULT float32 3:16x1024x64 JURP Reshape output_20
154 ~ | RESULT float32 3:16x64x1024 ISVO Transpose transpose_9 | RESULT float32 2:2048x512 RIFF Identity output_11
155 ~ | RESULT float32 3:16x64x1024 RVOH Transpose transpose_8 | RESULT float32 2:2048x512 RIFF Identity output_13
156 + | | RESULT float32 2:512x512 ABCA Transpose output_8
157 + | | RESULT float32 2:512x512 UDAE Transpose output_10
158 + | | RESULT float32 2:512x512 DBWG Transpose output_12
159 + | | RESULT float32 2:512x512 CCFY Transpose output_21
160 + | | RESULT float32 2:512x2000 VIXA Transpose output_26
161 + | | RESULT float32 2:512x2000 FZCJ Transpose output_30
162 + | | RESULT float32 2:2000x512 FYEP Transpose output_33
163 + | | OUTPUT float32 3:2x1024x512 SLWQ output_0
164 + | | OUTPUT int64 2:2x1024 IQUS output_1
165 + | | OUTPUT float32 1:512 YYYY output_2
166 + | | OUTPUT float32 1:512 YYYY output_3
167 + | | OUTPUT float32 1:512 YYYY output_4
168 = | OUTPUT float32 3:2x1024x512 ANUI embedding | OUTPUT float32 3:2x1024x512 ANUI output_5
169 = | OUTPUT float32 3:2x1024x1 EATD rsqrt | OUTPUT float32 3:2x1024x1 EATD output_6
170 + | | OUTPUT float32 3:2x1024x512 RIFF output_7
171 + | | OUTPUT float32 2:512x512 ABCA output_8
172 = | OUTPUT float32 2:2048x512 RIFF view_4 | OUTPUT float32 2:2048x512 RIFF output_9
173 + | | OUTPUT float32 2:512x512 UDAE output_10
174 - | OUTPUT float32 3:1x1024x64 VFPY cat |
175 ~ | OUTPUT float32 3:16x64x1024 RVOH transpose_8 | OUTPUT float32 2:2048x512 RIFF output_11
176 + | | OUTPUT float32 2:512x512 DBWG output_12
177 ~ | OUTPUT float32 3:16x64x1024 ISVO transpose_9 | OUTPUT float32 2:2048x512 RIFF output_13
178 + | | OUTPUT float32 4:1x1x1024x64 CJYF output_14
179 + | | OUTPUT float32 4:1x1x1024x64 GSEC output_15
180 ~ | OUTPUT float32 3:16x1024x64 JQRL transpose_10 | OUTPUT float32 3:16x1024x64 CPMB output_16
181 + | | OUTPUT float32 3:16x64x1024 SHZB output_17
182 = | OUTPUT float32 4:2x8x1024x1024 OONO detach_13 | OUTPUT float32 4:2x8x1024x1024 OONO output_18
183 = | OUTPUT float32 3:16x1024x1024 OONO transpose_7 | OUTPUT float32 3:16x1024x1024 OONO output_19
184 ~ | OUTPUT float32 2:2048x512 SSXY view_14 | OUTPUT float32 3:16x1024x64 JURP output_20
185 + | | OUTPUT float32 2:512x512 CCFY output_21
186 ~ | OUTPUT float32 2:2048x512 UUWU mm_3 | OUTPUT float32 2:2048x512 GGRE output_22
187 + | | OUTPUT float32 3:2x1024x512 ERYP output_23
188 ~ | OUTPUT float32 3:2x1024x1 SSCF rsqrt_1 | OUTPUT float32 3:2x1024x1 YZWY output_24
189 + | | OUTPUT float32 3:2x1024x512 VBVQ output_25
190 + | | OUTPUT float32 2:512x2000 VIXA output_26
191 ~ | OUTPUT float32 2:2048x512 BNKB view_15 | OUTPUT float32 2:2048x512 VBVQ output_27
192 + | | OUTPUT float32 3:2x1024x2000 JQQP output_28
193 ~ | OUTPUT float32 2:2048x2000 JCFR mm_4 | OUTPUT float32 3:2x1024x2000 QISR output_29
194 + | | OUTPUT float32 2:512x2000 FZCJ output_30
195 + | | OUTPUT float32 2:2048x512 VBVQ output_31
196 ~ | OUTPUT float32 2:2048x2000 PWLE mm_5 | OUTPUT float32 3:2x1024x2000 QRZF output_32
197 + | | OUTPUT float32 2:2000x512 FYEP output_33
198 ~ | OUTPUT float32 2:2048x2000 NBUQ view_17 | OUTPUT float32 2:2048x2000 KCDG output_34
199 ~ | OUTPUT float32 3:2x1024x512 DFFX add_7 | OUTPUT float32 3:2x1024x512 MODG output_35
200 ~ | OUTPUT float32 3:2x1024x1 LLDF rsqrt_2 | OUTPUT float32 3:2x1024x1 YXLM output_36
201 ~ | OUTPUT float32 3:2x1024x512 YKJO mul_14 | OUTPUT float32 3:2x1024x512 SLWQ output_37
Total running time of the script: (0 minutes 31.934 seconds)