301: Compares LLAMA exporters for onnxrt backend¶
The script compares models exported from pytorch with the onnxrt backend and with a debugging backend. It then runs a side-by-side comparison of the execution of both models.
To run the script:
python _doc/examples/plot_llama_diff_dort.py --help
The following example compares the forward step in mixed precision on cuda and produces all the intermediate onnx graphs. Use --mixed=1 to enable mixed precision and --backward=1 to compare the backward graphs as well.
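For example, a hypothetical invocation comparing the backward graphs in mixed precision on cuda could look like this (the option names come from the argument list below):
python _doc/examples/plot_llama_diff_dort.py --backward=1 --cuda=1 --mixed=1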
Some helpers¶
from experimental_experiment.args import get_parsed_args
script_args = get_parsed_args(
"plot_llama_diff_export",
description=__doc__,
part=("attention", "one value among attention, decoder, model"),
ortopt=(1, "run onnxruntime optimization"),
backward=(0, "does one operator for backward"),
cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
opset=(18, "onnx opset"),
    expose="part,ortopt,backward,cuda,mixed,opset",
)
import copy
import os
import warnings
import logging
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import onnxruntime
has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
print("onnxruntime not available.")
import sys
sys.exit(0)
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
optimize_model_proto,
ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
get_llama_model,
get_llama_attention,
get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import (
assert_all_close,
dump_onnx,
reorder_functions_in_proto,
inputs_from_onnx_model,
build_matching_inputs,
results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
train_loop,
make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
onnx_debug_backend,
get_decomposition_table,
)
has_cuda = has_cuda and torch.cuda.is_available()
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"
The exporting functions¶
print(f"part={script_args.part}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
backward = script_args.backward in (1, "1")
print(f"backward={backward}")
use_cuda = script_args.cuda in (1, "1")
print(f"cuda={use_cuda}")
use_mixed = script_args.mixed in (1, "1")
print(f"mixed={use_mixed}")
opset = int(script_args.opset)
print(f"opset={opset}")
part=attention
ortopt=True
backward=False
cuda=False
mixed=False
opset=18
Model and data¶
if unit_test_going():
kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
kwargs = dict(
input_dims=[(2, 1024)] * 2,
_attn_implementation="eager",
num_hidden_layers=1,
hidden_size=512,
vocab_size=4000,
intermediate_size=2000,
max_position_embeddings=2048,
num_attention_heads=8,
)
if script_args.part == "attention":
model, inputs = get_llama_attention(**kwargs)
elif script_args.part == "decoder":
model, inputs = get_llama_decoder(**kwargs)
elif script_args.part == "model":
model, inputs = get_llama_model(**kwargs)
else:
raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
if use_cuda:
model = model.to("cuda")
inputs = [[i.to("cuda") for i in inp] for inp in inputs]
print(f"simple run with {len(inputs)} inputs")
if backward:
if use_mixed:
assert use_cuda, "mixed precision only works with cuda"
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected = train_loop(copy.deepcopy(model), *inputs[0])
torch.cuda.synchronize()
else:
expected = train_loop(copy.deepcopy(model), *inputs[0])
print(
f"-- eager mode worked, {len(expected)} gradients, first one is "
f"{expected[0].shape}, {expected[0].dtype}"
)
else:
if use_mixed:
assert use_cuda, "mixed precision only works with cuda"
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected = model(*inputs[0])
torch.cuda.synchronize()
else:
expected = model(*inputs[0])
print(results_to_string(expected))
simple run with 2 inputs
torch.float32 (2, 1024, 512) [sum=595]
Exporting¶
folder = "dump_models"
storage = {}
if backward:
# onnxrt backend
local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)
optimized_mod = torch.compile(
copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
)
with dump_onnx("llama_onnxrt", folder=folder, clean=True):
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected_onnxrt = train_loop(optimized_mod, *inputs[0])
torch.cuda.synchronize()
else:
expected_onnxrt = train_loop(optimized_mod, *inputs[0])
assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
print(
f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
)
# debugging backend
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
*args,
dump_prefix=os.path.join(folder, "llama_debug"),
target_opset=opset,
storage=storage,
**kwargs,
),
decompositions=get_decomposition_table(),
)
onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)
    # disabled: the debugging backend is not run in mixed precision here
    if False and use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
got = train_loop(onnx_mod, *inputs[0])
torch.cuda.synchronize()
else:
got = train_loop(onnx_mod, *inputs[0])
assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
print(
f"-- debug backend worked, {len(got)} gradients, first one is "
f"{got[0].shape}, {got[0].dtype}"
)
else:
# onnxrt backend
local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
with dump_onnx("llama_onnxrt", folder=folder, clean=True):
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
torch.cuda.synchronize()
expected_onnxrt = optimized_mod(*inputs[0])
torch.cuda.synchronize()
else:
expected_onnxrt = optimized_mod(*inputs[0])
assert_all_close(expected, expected_onnxrt, atol=1e-2)
# debugging backend
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
*args,
dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
storage=storage,
**kwargs,
)
)
onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
if use_mixed:
with torch.autocast(device_type="cuda", dtype=torch.float16):
got = onnx_mod(*inputs[0])
else:
got = onnx_mod(*inputs[0])
assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
Applied 0 pattern rewrite rules.
Applied 0 pattern rewrite rules.
For the forward pass, each backend produces two files: an onnx model and the graph module printed in a text file. For the backward pass, there are two onnx models. This is then multiplied by the number of backends.
models = os.listdir(folder)
print(f"exported models: {models}")
exported models: ['llama_onnxrt_0.onnx', 'llama_debug_0.onnx', 'llama_debug_0.txt', 'llama_onnxrt_0.txt']
Inputs used by the debug backend:
feeds = storage["instance"][0]["inputs"][0]
for k, v in feeds.items():
    print(f"-- {k} {v.dtype} {v.shape}")
-- input0 float32 (512, 512)
-- input1 float32 (512, 512)
-- input2 float32 (512, 512)
-- input3 float32 (512, 512)
-- input4 float32 (2048, 64)
-- input5 float32 (2048, 64)
-- input6 float32 (2, 1024, 512)
-- input7 int64 (1, 1024)
-- input8 float32 (2, 1, 1024, 1024)
Let’s print the first lines of the graph module.
graph_module = storage["instance"][0]["graph_module"]
print("\n".join(str(graph_module.graph).split("\n")[:10]))
graph():
%primals_1 : [num_users=1] = placeholder[target=primals_1]
%primals_2 : [num_users=1] = placeholder[target=primals_2]
%primals_3 : [num_users=1] = placeholder[target=primals_3]
%primals_4 : [num_users=1] = placeholder[target=primals_4]
%primals_5 : [num_users=1] = placeholder[target=primals_5]
%primals_6 : [num_users=1] = placeholder[target=primals_6]
%primals_7 : [num_users=3] = placeholder[target=primals_7]
%primals_8 : [num_users=2] = placeholder[target=primals_8]
%primals_9 : [num_users=1] = placeholder[target=primals_9]
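To inspect the module beyond its first lines, one can also use torch.fx’s GraphModule.print_readable (available in recent pytorch versions); a minimal sketch:
# prints the python code generated for the whole graph module
graph_module.print_readable()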
Comparison and execution¶
if backward:
print(f"-- {len(storage['instance'])} onnx models were creates")
for i, inst in enumerate(storage["instance"]):
print(f" model {i}: {len(inst['inputs'])} runs")
# deal with backward
onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
onnx_models = list(sorted([m for m in models if m.endswith(".onnx") and "_1" in m]))
assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
model_onnxrt = os.path.join(folder, onnx_models[1])
model_debug = os.path.join(folder, onnx_models[0])
else:
onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
if len(onnx_models) == 2:
model_onnxrt = os.path.join(folder, onnx_models[1])
model_debug = os.path.join(folder, onnx_models[0])
else:
model_debug = os.path.join(folder, onnx_models[0])
# the following error may appear:
# Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
print(f"One model is missing, onnx_models={onnx_models}")
model_onnxrt = model_debug
print(f"model_onnxrt={model_onnxrt}")
print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx
The inputs of both models
print("onnxrt:", inputs_from_onnx_model(model_onnxrt))
print("debug:", inputs_from_onnx_model(model_debug))
onnxrt: [('INPUT', 'primals_4', 1, (512, 512)), ('INPUT', 'primals_1', 1, (512, 512)), ('INPUT', 'primals_7', 1, (2, 1024, 512)), ('INPUT', 'primals_2', 1, (512, 512)), ('INPUT', 'primals_3', 1, (512, 512)), ('INPUT', 'primals_5', 1, (2048, 64)), ('INPUT', 'primals_6', 1, (2048, 64)), ('INPUT', 'primals_8', 7, (1, 1024)), ('INPUT', 'primals_9', 1, (2, 1, 1024, 1024))]
debug: [('INPUT', 'input0', 1, (512, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (2048, 64)), ('INPUT', 'input5', 1, (2048, 64)), ('INPUT', 'input6', 1, (2, 1024, 512)), ('INPUT', 'input7', 7, (1, 1024)), ('INPUT', 'input8', 1, (2, 1, 1024, 1024))]
The inputs are not the same: the names and the order differ, and for model_debug some constants were moved into the initializer list.
print("debug:", inputs_from_onnx_model(model_debug, init=True))
debug: [('INPUT', 'input0', 1, (512, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (2048, 64)), ('INPUT', 'input5', 1, (2048, 64)), ('INPUT', 'input6', 1, (2, 1024, 512)), ('INPUT', 'input7', 7, (1, 1024)), ('INPUT', 'input8', 1, (2, 1, 1024, 1024)), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_8_64', 7, (4,)), ('INIT', 'init7_s1_0', 7, (1,)), ('INIT', 'init7_s1_1024', 7, (1,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s2_32_32', 7, (2,))]
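The same information can be read directly from the onnx proto; a minimal sketch, assuming model_debug is the path defined above:
import onnx

proto = onnx.load(model_debug)
# initializers are constants stored inside the graph, not runtime inputs
for init in proto.graph.initializer:
    print(init.name, onnx.TensorProto.DataType.Name(init.data_type), tuple(init.dims))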
Optimization and Verification¶
Let’s try the model with a python backend (reference implementation). First step: onnx-script uses many functions. The reference evaluator expects every function to be defined before it is used, so the order of functions in the model matters. No recursion is allowed by this runtime. We need to reorder the functions because function Rank is usually placed at the end of the model.
reorder_functions_in_proto(model_onnxrt)
'dump_models/llama_onnxrt_0.onnx'
Let’s load the models and optimize them.
debug = onnx.load(model_debug)
try:
onnxrt = optimize_model_proto(onnx.load(model_onnxrt))
except ImportError as e:
print("missing library", e)
onnxrt = debug
Applied 0 pattern rewrite rules.
Let’s apply onnxruntime optimization.
if ortopt:
providers = (
[("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
if use_cuda
else ["CPUExecutionProvider"]
)
with open(model_onnxrt.replace(".onnx", ".before.opt.onnx"), "wb") as f:
f.write(onnxrt.SerializeToString())
print(f"run onnxruntime optimization on {model_onnxrt}")
optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
ort_optimize(onnxrt, output=optimized, providers=providers)
onnxrt = onnx.load(optimized)
print(f"run onnxruntime optimization on {model_debug}")
optimized = model_debug.replace(".onnx", ".opt.onnx")
ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx
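One rough way to see what onnxruntime changed is to count the nodes before and after optimization; a minimal sketch, reusing the *.before.opt.onnx file saved above (only written when ortopt is enabled):
# compare node counts of the onnxrt model before and after ort optimization
before = onnx.load(model_onnxrt.replace(".onnx", ".before.opt.onnx"))
print(f"before: {len(before.graph.node)} nodes, after: {len(onnxrt.graph.node)} nodes")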
For what follows, we need to build two sets of matching inputs.
print("build_matching_inputs")
feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
print("done")
build_matching_inputs
done
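As a quick sanity check, a short sketch can list the names and shapes each model will receive:
# feeds targets the debug model, feedsrt targets the onnxrt model
for name, value in feedsrt.items():
    print("onnxrt feed:", name, value.shape)
for name, value in feeds.items():
    print("debug feed:", name, value.shape)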
We check that both models run.
out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
assert out_onnxrt
assert out_debug
# assert_all_close(out_onnxrt, out_debug)
Side-by-side comparison:
res1, res2, align, dc = compare_onnx_execution(
    onnxrt, debug, verbose=1, cls=ExtendedReferenceEvaluator, inputs=(feedsrt, feeds)
)
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 103 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 79 results
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 108 pairs
[compare_onnx_execution] done
001 = | INITIA int64 1:2 USAA ortshared_7_1_2_0_token_175 | INITIA int64 1:2 USAA ortshared_7_1_2_0_token_99
002 - | INITIA int64 1:4 CIKM ortshared_7_1_4_1_token_171 |
003 - | INITIA int64 1:1 KAAA ortshared_7_1_1_5_token_180 |
004 ~ | INITIA int64 1:3 QKMA ortshared_7_1_3_0_token_167 | INITIA int64 1:3 CKSA ortshared_7_1_3_0_token_98
005 ~ | INITIA int64 1:3 QMKA ortshared_7_1_3_1_token_168 | INITIA int64 1:4 CKIM ortshared_7_1_4_0_token_100
006 = | INITIA int64 1:1 AAAA ortshared_7_1_1_3_token_169 | INITIA int64 1:1 AAAA ortshared_7_1_1_0_token_97
007 ~ | INITIA int64 1:2 GGAA splits | INITIA int64 1:1 KAAA ortshared_7_1_1_2_token_106
008 = | INITIA int64 1:1 BAAA ortshared_7_1_1_1_token_163 | INITIA int64 1:1 BAAA ortshared_7_1_1_1_token_105
009 - | INITIA float32 IAAA ortshared_1_0_1_1_token_177 |
010 ~ | INITIA int64 1:2 GGAA splits_token_181 | INITIA int64 1:3 QKMA ortshared_7_1_3_1_token_102
011 - | INITIA int64 ZAAA ortshared_7_0_1_1_token_176 |
012 - | INITIA int64 1:4 CIKK ortshared_7_1_4_2_token_174 |
013 - | INITIA float32 BAAA ortshared_1_0_1_0_token_172 |
014 ~ | INITIA int64 1:3 QKKA ortshared_7_1_3_3_token_178 | INITIA int64 1:3 QMKA ortshared_7_1_3_3_token_107
015 - | INITIA int64 BAAA ortshared_7_0_1_0_token_164 |
016 - | INITIA int64 1:4 CKIM ortshared_7_1_4_0_token_165 |
017 ~ | INITIA int64 1:2 BKAA ortshared_7_1_2_1_token_179 | INITIA int64 1:2 GGAA ortshared_7_1_2_1_token_101
018 ~ | INITIA int64 1:3 CKSA ortshared_7_1_3_2_token_173 | INITIA int64 1:3 QKKA ortshared_7_1_3_2_token_103
019 = | INPUT float32 2:512x512 UCQB primals_4 | INPUT float32 2:512x512 UCQB input0
020 = | INPUT float32 2:512x512 URYC primals_1 | INPUT float32 2:512x512 URYC input1
021 - | INPUT float32 3:2x1024x512 YWBT primals_7 |
022 = | INPUT float32 2:512x512 VBXD primals_2 | INPUT float32 2:512x512 VBXD input2
023 = | INPUT float32 2:512x512 AUCY primals_3 | INPUT float32 2:512x512 AUCY input3
024 = | INPUT float32 2:2048x64 MDRB primals_5 | INPUT float32 2:2048x64 MDRB input4
025 = | INPUT float32 2:2048x64 ZHDU primals_6 | INPUT float32 2:2048x64 ZHDU input5
026 + | | INPUT float32 3:2x1024x512 YWBT input6
027 = | INPUT int64 2:1x1024 KAQG primals_8 | INPUT int64 2:1x1024 KAQG input7
028 = | INPUT float32 4:2x1x1024x1024 AAAA primals_9 | INPUT float32 4:2x1x1024x1024 AAAA input8
029 - | RESULT float32 2:512x512 UCQB Identity t_6 |
030 - | RESULT float32 4:2x1x1024x1024 AAAA Mul _inlfunc_aten_add|folded_2_other |
031 - | RESULT int64 2:1x1024 KAQG Expand _val_65 |
032 - | RESULT int64 3:1x1024x1 KAQG Unsqueeze _val_67 |
033 - | RESULT int64 3:1x1024x1 KAQG Concat _val_68 |
034 = | RESULT float32 2:1024x64 GSEC Slice slice_2 | RESULT float32 2:1024x64 GSEC Slice slice_2
035 - | RESULT float32 2:1024x64 GSEC Transpose _val_62 |
036 ~ | RESULT float32 3:1x1024x64 GSEC GatherND _val_69 | RESULT float32 3:1x1024x64 GSEC Gather index_1
037 = | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze aten_unsqueeze_116_n2 | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze output_5
038 = | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_5_out0 | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_4_out0
039 = | RESULT float32 2:2048x512 YWBT Reshape view | RESULT float32 2:2048x512 YWBT Reshape output_2
040 ~ | RESULT float32 2:2048x512 FVMX FusedMatMul mm_1 | RESULT float32 2:2048x512 XOOY Gemm mm_1
041 - | RESULT float32 3:2x1024x512 FVMX Reshape view_3 |
042 ~ | RESULT float32 4:2x1024x8x64 FVMX Reshape view_7 | RESULT float32 4:2x1024x8x64 XOOY Reshape view_7
043 ~ | RESULT float32 4:2x1024x8x32 AMAV Split Slice_178 | RESULT float32 4:2x1024x8x32 KQNE Split SlicesSplitPattern--slice_Tensor
044 ~ | RESULT float32 4:2x1024x8x32 GKNC Split Slice_195 | RESULT float32 4:2x1024x8x32 NZBV Split SlicesSplitPattern--slice_Tensor
045 ~ | RESULT float32 4:2x1024x8x32 UQNY Neg aten_neg_199_n0 | RESULT float32 4:2x1024x8x32 NBZF Neg neg2
046 ~ | RESULT float32 4:2x1024x8x64 VCMT Concat aten_cat_204_n0 | RESULT float32 4:2x1024x8x64 XRMK Concat cat2
047 ~ | RESULT float32 4:2x1024x8x64 NKQX Mul aten_mul_208_n0 | RESULT float32 4:2x1024x8x64 PBCM Mul mul4
048 = | RESULT float32 2:1024x64 CJYF Slice slice_1 | RESULT float32 2:1024x64 CJYF Slice slice_1
049 - | RESULT float32 2:1024x64 CJYF Transpose _val_53 |
050 ~ | RESULT float32 3:1x1024x64 CJYF GatherND _val_60 | RESULT float32 3:1x1024x64 CJYF Gather index
051 = | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze aten_unsqueeze_115_n2 | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze output_4
052 = | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_8_out0 | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_6_out0
053 ~ | RESULT float32 4:2x1024x8x64 ALUG Mul aten_mul_161_n0 | RESULT float32 4:2x1024x8x64 KMQI Mul mul3
054 ~ | RESULT float32 4:2x1024x8x64 MVKD Add _inlfunc_aten_add|folded_1_n3 | RESULT float32 4:2x1024x8x64 YOSU Add add_Tensor2
055 ~ | RESULT float32 4:2x8x64x1024 FCAN Transpose transpose_3 | RESULT float32 4:2x8x64x1024 KCYO Transpose transpose_3
056 - | RESULT float32 3:16x64x1024 FCAN Reshape view_10 |
057 - | RESULT float32 4:1x1x1024x64 GSEC Transpose unsqueeze_1 |
058 ~ | RESULT float32 2:2048x512 XOOY FusedMatMul mm | RESULT float32 2:2048x512 YWBT Reshape output_1
059 ~ | RESULT float32 3:2x1024x512 XOOY Reshape view_1 | RESULT float32 2:2048x512 LECY Gemm mm
060 ~ | RESULT float32 4:2x1024x8x64 XOOY Reshape view_6 | RESULT float32 4:2x1024x8x64 LECY Reshape view_6
061 ~ | RESULT float32 4:2x8x1024x64 DJQW Transpose transpose | RESULT float32 4:2x8x1024x64 JGZB Transpose transpose
062 ~ | RESULT float32 4:2x8x1024x32 ZBVW Split slice_3 | RESULT float32 4:2x8x1024x32 VEUG Split slice_3
063 ~ | RESULT float32 4:2x8x1024x32 DIVA Split slice_4 | RESULT float32 4:2x8x1024x32 NCGV Split slice_4
064 ~ | RESULT float32 4:2x8x1024x32 XSFA Neg neg | RESULT float32 4:2x8x1024x32 NYUF Neg neg
065 ~ | RESULT float32 4:2x8x1024x64 VSBW Concat cat | RESULT float32 4:2x8x1024x64 IDPL Concat cat
066 ~ | RESULT float32 4:2x8x1024x64 OCLC Mul mul_1 | RESULT float32 4:2x8x1024x64 YARS Mul mul_1
067 - | RESULT float32 4:1x1x1024x64 CJYF Transpose unsqueeze |
068 ~ | RESULT float32 4:2x8x1024x64 VANM Mul mul | RESULT float32 4:2x8x1024x64 OTPY Mul mul
069 ~ | RESULT float32 4:2x8x1024x64 KCYO Add add | RESULT float32 4:2x8x1024x64 NSFP Add add
070 - | RESULT float32 3:16x1024x64 KCYO Reshape view_9 |
071 - | RESULT float32 3:16x1024x1024 MSON MatMul bmm |
072 ~ | RESULT float32 4:2x8x1024x1024 MSON Reshape view_11 | RESULT float32 4:2x8x1024x1024 QPIE FusedMatMul div
073 - | RESULT float32 4:2x8x1024x1024 YGCR Div div |
074 ~ | RESULT float32 4:2x8x1024x1024 YGCR Add add_2 | RESULT float32 4:2x8x1024x1024 QPIE Add add_2
075 ~ | RESULT float32 4:2x8x1024x1024 ONNN Softmax _softmax | RESULT float32 4:2x8x1024x1024 ONNO Softmax output_8
076 - | RESULT float32 3:16x1024x1024 ONNN Reshape view_12 |
077 ~ | RESULT float32 2:2048x512 MQUP FusedMatMul mm_2 | RESULT float32 2:2048x512 YWBT Reshape output_3
078 ~ | RESULT float32 3:2x1024x512 MQUP Reshape view_5 | RESULT float32 2:2048x512 FVMX Gemm mm_2
079 ~ | RESULT float32 4:2x1024x8x64 MQUP Reshape view_8 | RESULT float32 4:2x1024x8x64 FVMX Reshape view_8
080 ~ | RESULT float32 4:2x8x1024x64 IUHD Transpose transpose_2 | RESULT float32 4:2x8x1024x64 ZCMY Transpose transpose_2
081 ~ | RESULT float32 3:16x1024x64 IUHD Reshape view_13 | RESULT float32 4:2x8x1024x64 VPTA MatMul view_11
082 ~ | RESULT float32 3:16x1024x64 IUNZ MatMul bmm_1 | RESULT float32 4:2x1024x8x64 FFHL Transpose transpose_4
083 ~ | RESULT float32 4:2x8x1024x64 IUNZ Reshape view_14 | RESULT float32 2:2048x512 FFHL Reshape output_12
084 ~ | RESULT float32 4:2x1024x8x64 SKKC Transpose transpose_4 | RESULT float32 2:2048x512 GDEI Gemm mm_3
085 ~ | RESULT float32 3:2x1024x512 SKKC Reshape view_15 | RESULT float32 3:2x1024x512 GDEI Reshape output_0
086 + | | RESULT float32 2:512x512 CXYY Transpose output_11
087 ~ | RESULT float32 2:2048x512 SKKC Reshape view_16 | RESULT float32 3:16x1024x64 ZCMY Reshape output_10
088 - | RESULT float32 2:2048x512 FJWU FusedMatMul mm_3 |
089 - | RESULT float32 3:2x1024x512 FJWU Reshape view_17 |
090 ~ | RESULT float32 3:16x1024x1024 ONNN Transpose transpose_6 | RESULT float32 3:16x1024x1024 ONNO Reshape output_9
091 + | | RESULT float32 3:16x64x1024 KCYO Reshape output_7
092 - | RESULT float32 4:2x8x1024x1024 ONNN Identity detach_3 |
093 ~ | RESULT float32 3:16x1024x64 FCAN Transpose transpose_9 | RESULT float32 3:16x1024x64 NSFP Reshape output_6
094 + | | OUTPUT float32 3:2x1024x512 GDEI output_0
095 ~ | RESULT float32 3:16x64x1024 KCYO Transpose transpose_8 | OUTPUT float32 2:2048x512 YWBT output_1
096 ~ | RESULT float32 3:16x64x1024 IUHD Transpose transpose_7 | OUTPUT float32 2:2048x512 YWBT output_2
097 = | OUTPUT float32 2:2048x512 YWBT view | OUTPUT float32 2:2048x512 YWBT output_3
098 - | OUTPUT float32 2:512x512 UCQB t_6 |
099 = | OUTPUT float32 4:1x1x1024x64 CJYF unsqueeze | OUTPUT float32 4:1x1x1024x64 CJYF output_4
100 = | OUTPUT float32 4:1x1x1024x64 GSEC unsqueeze_1 | OUTPUT float32 4:1x1x1024x64 GSEC output_5
101 ~ | OUTPUT float32 3:16x64x1024 IUHD transpose_7 | OUTPUT float32 3:16x1024x64 NSFP output_6
102 = | OUTPUT float32 3:16x64x1024 KCYO transpose_8 | OUTPUT float32 3:16x64x1024 KCYO output_7
103 - | OUTPUT float32 3:16x1024x64 FCAN transpose_9 |
104 ~ | OUTPUT float32 4:2x8x1024x1024 ONNN detach_3 | OUTPUT float32 4:2x8x1024x1024 ONNO output_8
105 ~ | OUTPUT float32 3:16x1024x1024 ONNN transpose_6 | OUTPUT float32 3:16x1024x1024 ONNO output_9
106 ~ | OUTPUT float32 2:2048x512 SKKC view_16 | OUTPUT float32 3:16x1024x64 ZCMY output_10
107 + | | OUTPUT float32 2:512x512 CXYY output_11
108 ~ | OUTPUT float32 3:2x1024x512 FJWU view_17 | OUTPUT float32 2:2048x512 FFHL output_12
Total running time of the script: (0 minutes 6.988 seconds)