Note
Go to the end to download the full example code.
301: Compares LLAMA exporters for onnxrt backend¶
The script compares exported models in pytorch using onnxrt backend. It tries to compare the execution of both models side by side.
To run the script:
python _doc/examples/plot_llama_diff_dort --help
The following example compares the forward step for mixed precision on cuda and produces all the intermediate onnx graphs.
You may use --backward=1
to compare the backward graphs.
Some helpers¶
from experimental_experiment.args import get_parsed_args
# Command-line options of this example. Each keyword below appears to be a
# ``(default value, help text)`` pair understood by ``get_parsed_args``.
script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    ortopt=(1, "run onnxruntime optimization"),
    backward=(0, "does one operator for backward"),
    cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,cuda,mixed,opset",
)
import copy
import os
import warnings
import logging
# onnxruntime is required by this example: when it cannot be imported,
# a message is printed and the script exits with status 0.
try:
    with warnings.catch_warnings():
        # Every warning emitted inside this block is ignored.
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)
import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
optimize_model_proto_oxs,
ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
get_llama_model,
get_llama_attention,
get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import (
assert_all_close,
dump_onnx,
reorder_functions_in_proto,
inputs_from_onnx_model,
build_matching_inputs,
results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
train_loop,
make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
onnx_debug_backend,
get_decomposition_table,
)
# CUDA counts as available only when onnxruntime lists its CUDA provider
# and torch reports a usable CUDA device.
has_cuda = has_cuda and torch.cuda.is_available()

# Disable every log record of severity ERROR and below.
logging.disable(logging.ERROR)

if has_cuda:
    provider = "cuda"
else:
    provider = "cpu"
The exporting functions¶
# Turn the parsed options into plain Python values and echo them.
# A flag is considered set when its value is the integer 1 or the string "1".
print(f"part={script_args.part}")

ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")

backward = script_args.backward in (1, "1")
print(f"backward={backward}")

use_cuda = script_args.cuda in (1, "1")
print(f"cuda={use_cuda}")

use_mixed = script_args.mixed in (1, "1")
print(f"mixed={use_mixed}")

opset = int(script_args.opset)
print(f"opset={opset}")
part=attention
ortopt=True
backward=False
cuda=False
mixed=False
opset=18
Model and data¶
# Keyword arguments forwarded to the llama helper that builds the model.
# Under unit tests only the input dimensions are given; otherwise the
# configuration values below are set explicitly as well.
kwargs = {"input_dims": [(2, 1024)] * 2}
if not unit_test_going():
    kwargs.update(
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )
# Build the requested part of the model together with its inputs.
# Any value other than attention, decoder or model is rejected.
if script_args.part == "attention":
    model, inputs = get_llama_attention(**kwargs)
elif script_args.part == "decoder":
    model, inputs = get_llama_decoder(**kwargs)
elif script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
# Eager-mode reference run: ``expected`` is what the compiled models are
# compared against later on.
if use_cuda:
    # Move the model and every tensor of every input set to the GPU.
    model = model.to("cuda")
    inputs = [[tensor.to("cuda") for tensor in group] for group in inputs]

print(f"simple run with {len(inputs)} inputs")

if backward:
    # train_loop runs on a copy of the model; only the first input set is used.
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = train_loop(copy.deepcopy(model), *inputs[0])
            torch.cuda.synchronize()
    else:
        expected = train_loop(copy.deepcopy(model), *inputs[0])
    print(
        f"-- eager mode worked, {len(expected)} gradients, first one is "
        f"{expected[0].shape}, {expected[0].dtype}"
    )
else:
    # Forward only: call the model directly on the first input set.
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = model(*inputs[0])
            torch.cuda.synchronize()
    else:
        expected = model(*inputs[0])
    print(results_to_string(expected))
simple run with 2 inputs
torch.float32 (2, 1024, 512) [sum=-453]
Exporting¶
if hasattr(torch._dynamo.variables.misc, "LoggingLoggerVariable"):
    # A tweak to make torch.export.export work:
    # calls made on that variable class are replaced by a no-op.
    torch._dynamo.variables.misc.LoggingLoggerVariable.call_method = lambda *_, **__: None

# Directory receiving the dumped models.
folder = "dump_models"
# Dictionary handed to the debugging backend through its ``storage`` argument;
# later sections read its "instance" entry.
storage = {}
# Run the model through two torch.compile backends -- the onnxrt backend and
# the debugging backend -- and compare each result with the eager-mode
# ``expected`` value computed earlier.
if backward:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)
    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
    )
    # The run happens inside dump_onnx so that the exported models are
    # written into ``folder`` (presumably under the "llama_onnxrt" prefix).
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = train_loop(optimized_mod, *inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = train_loop(optimized_mod, *inputs[0])
    # Only the first gradient is compared.
    assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
    print(
        f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
        f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
    )

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **kwargs,
        ),
        decompositions=get_decomposition_table(),
    )
    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)

    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            got = train_loop(onnx_mod, *inputs[0])
            torch.cuda.synchronize()
    else:
        got = train_loop(onnx_mod, *inputs[0])
    # A looser tolerance is used with mixed precision.
    assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
    print(
        f"-- debug backend worked, {len(got)} gradients, first one is "
        f"{got[0].shape}, {got[0].dtype}"
    )
else:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = optimized_mod(*inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = optimized_mod(*inputs[0])
    assert_all_close(expected, expected_onnxrt, atol=1e-2)

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            # Honour the --opset option, as the backward branch does,
            # rather than a hard-coded opset that silently ignores it.
            target_opset=opset,
            storage=storage,
            **kwargs,
        )
    )

    onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            got = onnx_mod(*inputs[0])
    else:
        # Best effort: a failure of the debugging backend is reported
        # and the comparison below is skipped.
        try:
            got = onnx_mod(*inputs[0])
        except Exception as e:
            print(f"ERROR: {e}")
            got = None
    if got is not None:
        assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:101: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
Applied 9 of general pattern rewrite rules.
Applied 1 of general pattern rewrite rules.
For forward, there are two files, one onnx model and the graph module printed in a txt file. For backward, there are two onnx models. Then it is multiplied by the number of backends.
# Every file found in the dump folder after both backends ran.
models = os.listdir(folder)
print(f"exported models: {models}")
exported models: ['llama_debug_0.onnx', 'llama_onnxrt_0.txt', 'llama_debug_0.txt', 'llama_onnxrt_0.onnx']
Inputs used by the debug backend
-- input0 float32 (2, 1024, 512)
-- input1 float32 (512, 512)
-- input2 float32 (512, 512)
-- input3 float32 (512, 512)
-- input4 float32 (32,)
-- input5 int64 (1, 1024)
-- input6 float32 (2, 1, 1024, 1024)
-- input7 float32 (512, 512)
Let’s see the first lines of the graph module
# "instance" is present in storage only when the debugging backend stored something.
if "instance" in storage:
    graph_module = storage["instance"][0]["graph_module"]
    # Show only the first ten lines of the textual graph.
    print("\n".join(str(graph_module.graph).split("\n")[:10]))
graph():
%primals_1 : [num_users=3] = placeholder[target=primals_1]
%primals_2 : [num_users=1] = placeholder[target=primals_2]
%primals_3 : [num_users=1] = placeholder[target=primals_3]
%primals_4 : [num_users=1] = placeholder[target=primals_4]
%primals_5 : [num_users=1] = placeholder[target=primals_5]
%primals_6 : [num_users=1] = placeholder[target=primals_6]
%primals_7 : [num_users=1] = placeholder[target=primals_7]
%primals_8 : [num_users=1] = placeholder[target=primals_8]
%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%primals_2,), kwargs = {})
Comparison and execution¶
# Pick, among the dumped files, the onnx model produced by each backend.
# After sorting, the "llama_debug" file comes before the "llama_onnxrt" one,
# hence index 0 for the debug model and index 1 for the onnxrt model.
if "instance" in storage:
    if backward:
        print(f"-- {len(storage['instance'])} onnx models were created")
        for i, inst in enumerate(storage["instance"]):
            print(f" model {i}: {len(inst['inputs'])} runs")

        # deal with backward
        onnx_models = sorted(m for m in models if m.endswith(".onnx"))
        assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
        # Keep only the files whose name contains "_1".
        onnx_models = sorted(m for m in models if m.endswith(".onnx") and "_1" in m)
        assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
        model_onnxrt = os.path.join(folder, onnx_models[1])
        model_debug = os.path.join(folder, onnx_models[0])
    else:
        onnx_models = sorted(m for m in models if m.endswith(".onnx"))
        if len(onnx_models) == 2:
            model_onnxrt = os.path.join(folder, onnx_models[1])
            model_debug = os.path.join(folder, onnx_models[0])
        else:
            model_debug = os.path.join(folder, onnx_models[0])
            # the following error may appear:
            # Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
            print(f"One model is missing, onnx_models={onnx_models}")
            # Fall back on the debug model so that the next sections still run.
            model_onnxrt = model_debug

    print(f"model_onnxrt={model_onnxrt}")
    print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx
The inputs of both models
# Display the inputs declared by each of the two onnx files.
if "instance" in storage:
    print("onnxrt:", inputs_from_onnx_model(model_onnxrt))
    print("debug:", inputs_from_onnx_model(model_debug))
onnxrt: [('INPUT', 'primals_2', 1, (512, 512)), ('INPUT', 'primals_1', 1, (2, 1024, 512)), ('INPUT', 'primals_3', 1, (512, 512)), ('INPUT', 'primals_4', 1, (512, 512)), ('INPUT', 'primals_5', 1, (32,)), ('INPUT', 'primals_6', 7, (1, 1024)), ('INPUT', 'primals_7', 1, (2, 1, 1024, 1024)), ('INPUT', 'primals_8', 1, (512, 512))]
debug: [('INPUT', 'input0', 1, (2, 1024, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (32,)), ('INPUT', 'input5', 7, (1, 1024)), ('INPUT', 'input6', 1, (2, 1, 1024, 1024)), ('INPUT', 'input7', 1, (512, 512))]
The inputs are not the same. The first model has more inputs, and some of them were moved into the initializer list of model_debug.
# Same listing for the debug model, this time with ``init=True``.
if "instance" in storage:
    print("debug:", inputs_from_onnx_model(model_debug, init=True))
debug: [('INPUT', 'input0', 1, (2, 1024, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (32,)), ('INPUT', 'input5', 7, (1, 1024)), ('INPUT', 'input6', 1, (2, 1, 1024, 1024)), ('INPUT', 'input7', 1, (512, 512)), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_-1_64', 7, (4,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_2', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s2_0_2', 7, (2,)), ('INIT', 'init7_s2_32_32', 7, (2,))]
Optimization and Verification¶
Let’s try the model with a python backend (reference implementation). First step, onnxscript uses many functions. The reference evaluation expects every function to be defined so the order of functions in the model matters. Recursion is not allowed by this runtime. We need to reorder the functions because function Rank is usually placed at the end of the model.
# Reorder the functions of the onnxrt model; the helper receives the file path.
if "instance" in storage:
    reorder_functions_in_proto(model_onnxrt)
Let’s load the models and optimize them.
# Load both models; the onnxrt one additionally goes through
# optimize_model_proto_oxs.
if "instance" in storage:
    debug = onnx.load(model_debug)
    try:
        onnxrt = optimize_model_proto_oxs(onnx.load(model_onnxrt))
    except ImportError as e:
        # The optimizer needs a library that is not installed:
        # report it and reuse the debug model instead.
        print("missing library", e)
        onnxrt = debug
Let’s apply onnxruntime optimization
# Optional step (--ortopt): run onnxruntime optimization on both models and
# continue with the optimized files.
if "instance" in storage and ortopt:
    if use_cuda:
        providers = [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
    else:
        providers = ["CPUExecutionProvider"]

    # Save the onnxrt model as it is before onnxruntime optimizes it.
    with open(model_onnxrt.replace(".onnx", ".before.opt.onnx"), "wb") as f:
        f.write(onnxrt.SerializeToString())

    print(f"run onnxruntime optimization on {model_onnxrt}")
    optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
    ort_optimize(onnxrt, output=optimized, providers=providers)
    onnxrt = onnx.load(optimized)

    print(f"run onnxruntime optimization on {model_debug}")
    optimized = model_debug.replace(".onnx", ".opt.onnx")
    ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
    debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx
For what follows, we need to build two lists of matching inputs.
# Derive the feeds of the onnxrt model from the feeds of the debug model.
# NOTE(review): ``feeds`` is not assigned anywhere in the code shown on this
# page; it must be defined before this point -- confirm against the script.
if "instance" in storage:
    print("build_matching_inputs")
    feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
    print("done")
build_matching_inputs
done
We check both models are running.
# Run both models with the reference evaluator and make sure
# each of them returns something.
if "instance" in storage:
    out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
    out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
    assert out_onnxrt
    assert out_debug
    # assert_all_close(out_onnxrt, out_debug)
Side by side
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 94 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 94 results (first model)
[compare_onnx_execution] got 82 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 105 pairs
[compare_onnx_execution] done
001 = | INITIA int64 1:2 USAA _val_301 | INITIA int64 1:2 USAA init7_s2_2048_512
002 - | INITIA int64 AAAA aten_unsqueeze_75_dim_0 |
003 ~ | INITIA int64 1:4 CIKK _val_274 | INITIA int64 1:3 CKSA init7_s3_2_1024_512
004 - | INITIA int64 BAAA aten_unsqueeze_311_dim_0 |
005 - | INITIA int64 1:3 QMKA _val_269 |
006 ~ | INITIA int64 1:3 CKSA _val_96 | INITIA int64 1:4 CKZM init7_s4_2_1024_-1_64
007 - | INITIA int64 CAAA aten_unsqueeze_159_dim_0 |
008 ~ | INITIA int64 1:3 QKMA _val_264 | INITIA int64 1:1 BAAA init7_s1_1
009 - | INITIA int64 1:4 CKZM _val_137 |
010 ~ | INITIA int64 1:2 GGAA splits | INITIA int64 1:2 ACAA init7_s2_0_2
011 - | INITIA float32 IAAA _val_276 |
012 ~ | INITIA int64 1:3 CKZA _val_298 | INITIA int64 1:3 QKMA init7_s3_16_1024_64
013 ~ | INITIA int64 1:3 QKKA _val_287 | INITIA int64 1:3 QMKA init7_s3_16_64_1024
014 = | INITIA int64 1:2 GGAA splits_token_9 | INITIA int64 1:2 GGAA init7_s2_32_32
015 ~ | INITIA int64 1:4 CIKM _val_293 | INITIA int64 1:3 QKKA init7_s3_16_1024_1024
016 + | | INPUT float32 3:2x1024x512 PYUE input0
017 = | INPUT float32 2:512x512 XZAG primals_2 | INPUT float32 2:512x512 XZAG input1
018 - | INPUT float32 3:2x1024x512 PYUE primals_1 |
019 = | INPUT float32 2:512x512 PVAD primals_3 | INPUT float32 2:512x512 PVAD input2
020 = | INPUT float32 2:512x512 GBBF primals_4 | INPUT float32 2:512x512 GBBF input3
021 = | INPUT float32 1:32 DAAA primals_5 | INPUT float32 1:32 DAAA input4
022 = | INPUT int64 2:1x1024 KAQG primals_6 | INPUT int64 2:1x1024 KAQG input5
023 = | INPUT float32 4:2x1x1024x1024 AAAA primals_7 | INPUT float32 4:2x1x1024x1024 AAAA input6
024 = | INPUT float32 2:512x512 CTSV primals_8 | INPUT float32 2:512x512 CTSV input7
025 - | RESULT float32 2:512x512 CTSV Identity t_6 |
026 = | RESULT int64 3:1x1x1024 KAQG Unsqueeze unsqueeze_2 | RESULT int64 3:1x1x1024 KAQG Unsqueeze unsqueeze_2
027 = | RESULT float32 3:1x1x1024 KAQG Cast _to_copy | RESULT float32 3:1x1x1024 KAQG Cast _to_copy
028 - | RESULT float32 2:1x32 DAAA Unsqueeze unsqueeze |
029 = | RESULT float32 3:1x32x1 DAAA Unsqueeze unsqueeze_1 | RESULT float32 3:1x32x1 DAAA Unsqueeze unsqueeze_1
030 + | | RESULT float32 3:1x32x1024 EFXM MatMul bmm
031 - | RESULT float32 3:1x1024x32 XCHM FusedMatMul transpose_3 |
032 - | RESULT float32 3:1x1024x64 VFPY Concat cat |
033 ~ | RESULT float32 3:1x1024x64 GSEC Sin sin | RESULT float32 3:1x64x1024 JKJK Concat cat_token_5
034 ~ | RESULT float32 4:1x1x1024x64 GSEC Unsqueeze unsqueeze_4 | RESULT float32 3:1x64x1024 RMRM Sin sin_token_7
035 + | | RESULT float32 4:1x1x64x1024 RMRM Unsqueeze Opset8
036 = | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_4_out0 | RESULT float32 4:1x1024x1x64 GSEC Transpose Transpose_token_10_out0
037 = | RESULT float32 2:2048x512 PYUE Reshape view | RESULT float32 2:2048x512 PYUE Reshape output_2
038 ~ | RESULT float32 2:2048x512 SUEW FusedMatMul mm_1 | RESULT float32 2:2048x512 SUEW Gemm mm_1
039 - | RESULT float32 3:2x1024x512 SUEW Reshape _unsafe_view_1 |
040 = | RESULT float32 4:2x1024x8x64 SUEW Reshape view_4 | RESULT float32 4:2x1024x8x64 SUEW Reshape view_4
041 = | RESULT float32 4:2x1024x8x32 UFUT Split Slice_263 | RESULT float32 4:2x1024x8x32 UFUT Split SlicesSplitPattern--slice_Tensor
042 = | RESULT float32 4:2x1024x8x32 ZOKD Split Slice_280 | RESULT float32 4:2x1024x8x32 ZOKD Split SlicesSplitPattern--slice_Tensor
043 = | RESULT float32 4:2x1024x8x32 BMQX Neg Neg_290 | RESULT float32 4:2x1024x8x32 BMQX Neg neg2
044 = | RESULT float32 4:2x1024x8x64 URKR Concat Concat_294 | RESULT float32 4:2x1024x8x64 URKR Concat cat3
045 = | RESULT float32 4:2x1024x8x64 XKLR Mul Mul_315 | RESULT float32 4:2x1024x8x64 XKLR Mul mul_Tensor10
046 + | | RESULT float32 3:1x64x1024 NHNH Cos cos_token_13
047 ~ | RESULT float32 3:1x1024x64 CJYF Cos cos | RESULT float32 4:1x1x64x1024 NHNH Unsqueeze Opset7
048 - | RESULT float32 4:1x1x1024x64 CJYF Unsqueeze unsqueeze_3 |
049 = | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_6_out0 | RESULT float32 4:1x1024x1x64 CJYF Transpose Transpose_token_16_out0
050 = | RESULT float32 4:2x1024x8x64 QRHY Mul Mul_313 | RESULT float32 4:2x1024x8x64 QRHY Mul mul_Tensor9
051 = | RESULT float32 4:2x1024x8x64 NBTQ Add Add_317 | RESULT float32 4:2x1024x8x64 NBTQ Add add_Tensor2
052 = | RESULT float32 4:2x8x64x1024 MCDE Transpose transpose_4 | RESULT float32 4:2x8x64x1024 MCDE Transpose transpose_4
053 + | | RESULT float32 4:1x1x1024x64 GSEC Transpose output_5
054 - | RESULT float32 3:16x64x1024 MCDE Reshape _unsafe_view_4 |
055 ~ | RESULT float32 2:2048x512 MVCJ FusedMatMul mm | RESULT float32 2:2048x512 PYUE Reshape output_1
056 ~ | RESULT float32 3:2x1024x512 MVCJ Reshape _unsafe_view | RESULT float32 2:2048x512 MVCJ Gemm mm
057 = | RESULT float32 4:2x1024x8x64 MVCJ Reshape view_3 | RESULT float32 4:2x1024x8x64 MVCJ Reshape view_3
058 = | RESULT float32 4:2x8x1024x64 RQST Transpose transpose | RESULT float32 4:2x8x1024x64 RQST Transpose transpose
059 = | RESULT float32 4:2x8x1024x32 KJGH Split slice_4 | RESULT float32 4:2x8x1024x32 KJGH Split slice_4
060 = | RESULT float32 4:2x8x1024x32 HHML Split slice_5 | RESULT float32 4:2x8x1024x32 HHML Split slice_5
061 = | RESULT float32 4:2x8x1024x32 TTOP Neg neg | RESULT float32 4:2x8x1024x32 TTOP Neg neg
062 = | RESULT float32 4:2x8x1024x64 CDUW Concat cat_1 | RESULT float32 4:2x8x1024x64 CDUW Concat cat_1
063 = | RESULT float32 4:2x8x1024x64 NUIG Mul mul_3 | RESULT float32 4:2x8x1024x64 NUIG Mul mul_3
064 + | | RESULT float32 4:1x1x1024x64 CJYF Transpose output_4
065 = | RESULT float32 4:2x8x1024x64 OUMI Mul mul_2 | RESULT float32 4:2x8x1024x64 OUMI Mul mul_2
066 = | RESULT float32 4:2x8x1024x64 BPUO Add add | RESULT float32 4:2x8x1024x64 BPUO Add add
067 - | RESULT float32 3:16x1024x64 BPUO Reshape _unsafe_view_3 |
068 - | RESULT float32 3:16x1024x1024 CTYN MatMul bmm_1 |
069 - | RESULT float32 4:2x8x1024x1024 CTYN Reshape view_9 |
070 ~ | RESULT float32 4:2x8x1024x1024 EFQC Div div | RESULT float32 4:2x8x1024x1024 EFQC FusedMatMul div
071 = | RESULT float32 4:2x8x1024x1024 EFQC Add add_2 | RESULT float32 4:2x8x1024x1024 EFQC Add add_2
072 = | RESULT float32 4:2x8x1024x1024 NNNO Softmax _softmax | RESULT float32 4:2x8x1024x1024 NNNO Softmax output_8
073 - | RESULT float32 3:16x1024x1024 NNNO Reshape view_10 |
074 ~ | RESULT float32 2:2048x512 UUMW FusedMatMul mm_2 | RESULT float32 2:2048x512 PYUE Reshape output_3
075 ~ | RESULT float32 3:2x1024x512 UUMW Reshape _unsafe_view_2 | RESULT float32 2:2048x512 UUMW Gemm mm_2
076 = | RESULT float32 4:2x1024x8x64 UUMW Reshape view_5 | RESULT float32 4:2x1024x8x64 UUMW Reshape view_5
077 = | RESULT float32 4:2x8x1024x64 KGKZ Transpose transpose_2 | RESULT float32 4:2x8x1024x64 KGKZ Transpose transpose_2
078 ~ | RESULT float32 3:16x1024x64 KGKZ Reshape _unsafe_view_5 | RESULT float32 4:2x8x1024x64 GMGA MatMul view_11
079 ~ | RESULT float32 3:16x1024x64 GMGA MatMul bmm_2 | RESULT float32 4:2x1024x8x64 IJJX Transpose transpose_5
080 ~ | RESULT float32 4:2x8x1024x64 GMGA Reshape view_11 | RESULT float32 2:2048x512 IJJX Reshape output_12
081 ~ | RESULT float32 4:2x1024x8x64 IJJX Transpose transpose_5 | RESULT float32 2:2048x512 GHZE Gemm mm_3
082 ~ | RESULT float32 3:2x1024x512 IJJX Reshape view_12 | RESULT float32 3:2x1024x512 GHZE Reshape output_0
083 + | | RESULT float32 3:16x1024x1024 NNNO Reshape output_9
084 ~ | RESULT float32 2:2048x512 IJJX Reshape view_13 | RESULT float32 3:16x64x1024 MCDE Reshape output_7
085 ~ | RESULT float32 2:2048x512 GHZE FusedMatMul mm_3 | RESULT float32 3:16x1024x64 BPUO Reshape output_6
086 ~ | RESULT float32 3:2x1024x512 GHZE Reshape _unsafe_view_6 | RESULT float32 3:16x1024x64 KGKZ Reshape output_10
087 + | | RESULT float32 2:512x512 RYSA Transpose output_11
088 - | RESULT float32 3:16x1024x1024 NNNO Transpose transpose_7 |
089 - | RESULT float32 4:2x8x1024x1024 NNNO Identity detach_3 |
090 ~ | RESULT float32 3:16x1024x64 MCDE Transpose transpose_10 | OUTPUT float32 3:2x1024x512 GHZE output_0
091 ~ | RESULT float32 3:16x64x1024 BPUO Transpose transpose_9 | OUTPUT float32 2:2048x512 PYUE output_1
092 - | RESULT float32 3:16x64x1024 KGKZ Transpose transpose_8 |
093 = | OUTPUT float32 2:2048x512 PYUE view | OUTPUT float32 2:2048x512 PYUE output_2
094 - | OUTPUT float32 2:512x512 CTSV t_6 |
095 ~ | OUTPUT float32 3:16x64x1024 KGKZ transpose_8 | OUTPUT float32 2:2048x512 PYUE output_3
096 ~ | OUTPUT float32 3:1x1024x64 VFPY cat | OUTPUT float32 4:1x1x1024x64 CJYF output_4
097 + | | OUTPUT float32 4:1x1x1024x64 GSEC output_5
098 + | | OUTPUT float32 3:16x1024x64 BPUO output_6
099 ~ | OUTPUT float32 3:16x64x1024 BPUO transpose_9 | OUTPUT float32 3:16x64x1024 MCDE output_7
100 - | OUTPUT float32 3:16x1024x64 MCDE transpose_10 |
101 = | OUTPUT float32 4:2x8x1024x1024 NNNO detach_3 | OUTPUT float32 4:2x8x1024x1024 NNNO output_8
102 = | OUTPUT float32 3:16x1024x1024 NNNO transpose_7 | OUTPUT float32 3:16x1024x1024 NNNO output_9
103 ~ | OUTPUT float32 2:2048x512 IJJX view_13 | OUTPUT float32 3:16x1024x64 KGKZ output_10
104 + | | OUTPUT float32 2:512x512 RYSA output_11
105 ~ | OUTPUT float32 3:2x1024x512 GHZE _unsafe_view_6 | OUTPUT float32 2:2048x512 IJJX output_12
Total running time of the script: (0 minutes 6.666 seconds)
Related examples
101: A custom backend for torch
102: Fuse kernels in a small Llama Model