301: Compares LLAMA exporters for onnxrt backend

The script compares models exported from pytorch with the onnxrt backend. It does a side-by-side comparison of the execution of both models.

To run the script:

python _doc/examples/plot_llama_diff_dort.py --help

The following example compares the forward step for mixed precision on cuda and produces all the intermediate onnx graphs.

python _doc/examples/plot_llama_diff_dort.py --part model --ortopt 1 \
        --cuda 1 --backward 0 --mixed 1

You may use --backward=1 to compare the backward graphs.
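
For instance, the same configuration with the backward graphs:

python _doc/examples/plot_llama_diff_dort.py --part model --ortopt 1 \
        --cuda 1 --backward 1 --mixed 1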

Some helpers

from experimental_experiment.args import get_parsed_args

script_args = get_parsed_args(
    "plot_llama_diff_dort",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    ortopt=(1, "run onnxruntime optimization"),
    backward=(0, "compare the backward graphs"),
    cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
    opset=(18, "onnx opset"),
    expose="part,ortopt,backward,cuda,mixed,opset",
)
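
get_parsed_args builds the command line from (default, help) pairs. A rough argparse equivalent, given here only as a sketch of what the call above exposes:

import argparse

parser = argparse.ArgumentParser("plot_llama_diff_dort", description=__doc__)
parser.add_argument("--part", default="attention",
                    help="one value among attention, decoder, model")
parser.add_argument("--ortopt", default=1, type=int, help="run onnxruntime optimization")
parser.add_argument("--backward", default=0, type=int, help="compare the backward graphs")
parser.add_argument("--cuda", default=0, type=int, help="use cuda or not")
parser.add_argument("--mixed", default=0, type=int, help="use mixed precision")
parser.add_argument("--opset", default=18, type=int, help="onnx opset")
script_args = parser.parse_args()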


import copy
import os
import warnings
import logging

try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
    optimize_model_proto_oxs,
    ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
    get_llama_model,
    get_llama_attention,
    get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import (
    assert_all_close,
    dump_onnx,
    reorder_functions_in_proto,
    inputs_from_onnx_model,
    build_matching_inputs,
    results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
    train_loop,
    make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
    onnx_debug_backend,
    get_decomposition_table,
)

has_cuda = has_cuda and torch.cuda.is_available()
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"

The script arguments

print(f"part={script_args.part}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
backward = script_args.backward in (1, "1")
print(f"backward={backward}")
use_cuda = script_args.cuda in (1, "1")
print(f"cuda={use_cuda}")
use_mixed = script_args.mixed in (1, "1")
print(f"mixed={use_mixed}")
opset = int(script_args.opset)
print(f"opset={opset}")
part=attention
ortopt=True
backward=False
cuda=False
mixed=False
opset=18

Model and data

if unit_test_going():
    kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
    kwargs = dict(
        input_dims=[(2, 1024)] * 2,
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

if script_args.part == "attention":
    model, inputs = get_llama_attention(**kwargs)
elif script_args.part == "decoder":
    model, inputs = get_llama_decoder(**kwargs)
elif script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")

if use_cuda:
    model = model.to("cuda")
    inputs = [[i.to("cuda") for i in inp] for inp in inputs]

print(f"simple run with {len(inputs)} inputs")
if backward:
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = train_loop(copy.deepcopy(model), *inputs[0])
            torch.cuda.synchronize()
    else:
        expected = train_loop(copy.deepcopy(model), *inputs[0])
    print(
        f"-- eager mode worked, {len(expected)} gradients, first one is "
        f"{expected[0].shape}, {expected[0].dtype}"
    )
else:
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = model(*inputs[0])
            torch.cuda.synchronize()
    else:
        expected = model(*inputs[0])
    print(results_to_string(expected))
simple run with 2 inputs
torch.float32 (2, 1024, 512) [sum=-453]
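
train_loop is imported from training_helper and is not reproduced here; judging from the printed messages, it runs one forward and one backward pass and returns the gradients. A minimal sketch of that idea, as an assumption rather than the actual implementation:

def train_loop_sketch(model, *args):
    # forward pass reduced to a scalar loss
    output = model(*args)
    loss = output[0].sum() if isinstance(output, tuple) else output.sum()
    # backward pass
    loss.backward()
    # return the gradients, mirroring what the script prints
    return [p.grad for p in model.parameters()]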

Exporting

if hasattr(torch._dynamo.variables.misc, "LoggingLoggerVariable"):
    # A tweak to make torch.export.export work.
    torch._dynamo.variables.misc.LoggingLoggerVariable.call_method = lambda *_, **__: None


folder = "dump_models"
storage = {}

if backward:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)

    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
    )

    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = train_loop(optimized_mod, *inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = train_loop(optimized_mod, *inputs[0])
    assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
    print(
        f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
        f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
    )

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **kwargs,
        ),
        decompositions=get_decomposition_table(),
    )
    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)

    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            got = train_loop(onnx_mod, *inputs[0])
            torch.cuda.synchronize()
    else:
        got = train_loop(onnx_mod, *inputs[0])
    assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
    print(
        f"-- debug backend worked, {len(got)} gradients, first one is "
        f"{got[0].shape}, {got[0].dtype}"
    )

else:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = optimized_mod(*inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = optimized_mod(*inputs[0])
    assert_all_close(expected, expected_onnxrt, atol=1e-2)

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=17,
            storage=storage,
            **kwargs,
        )
    )

    onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            got = onnx_mod(*inputs[0])
    else:
        try:
            got = onnx_mod(*inputs[0])
        except Exception as e:
            print(f"ERROR: {e}")
            got = None
    if got is not None:
        assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:101: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
Applied 9 of general pattern rewrite rules.
Applied 1 of general pattern rewrite rules.

For the forward pass, there are two files: one onnx model and the graph module printed in a text file. For the backward pass, there are two onnx models. This is then multiplied by the number of backends.

models = os.listdir(folder)
print(f"exported models: {models}")
exported models: ['llama_debug_0.onnx', 'llama_onnxrt_0.txt', 'llama_debug_0.txt', 'llama_onnxrt_0.onnx']
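
Each dumped file can be inspected with the standard onnx API, for instance to count nodes and inputs:

for name in sorted(models):
    if name.endswith(".onnx"):
        proto = onnx.load(os.path.join(folder, name))
        print(f"{name}: {len(proto.graph.node)} nodes, {len(proto.graph.input)} inputs")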

Inputs used by the debug backend

if "instance" in storage:
    feeds = storage["instance"][0]["inputs"][0]
    for k, v in feeds.items():
        print(f"-- {k} {v.dtype} {v.shape}")
-- input0 float32 (2, 1024, 512)
-- input1 float32 (512, 512)
-- input2 float32 (512, 512)
-- input3 float32 (512, 512)
-- input4 float32 (32,)
-- input5 int64 (1, 1024)
-- input6 float32 (2, 1, 1024, 1024)
-- input7 float32 (512, 512)

Let’s print the first lines of the graph module.

if "instance" in storage:
    graph_module = storage["instance"][0]["graph_module"]
    print("\n".join(str(graph_module.graph).split("\n")[:10]))
graph():
    %primals_1 : [num_users=3] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=1] = placeholder[target=primals_5]
    %primals_6 : [num_users=1] = placeholder[target=primals_6]
    %primals_7 : [num_users=1] = placeholder[target=primals_7]
    %primals_8 : [num_users=1] = placeholder[target=primals_8]
    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%primals_2,), kwargs = {})
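
The stored object is a torch.fx.GraphModule, so the usual fx inspection helpers apply as well, for instance (this sketch requires the tabulate package):

if "instance" in storage:
    graph_module = storage["instance"][0]["graph_module"]
    # one row per fx node: opcode, name, target, args, kwargs
    graph_module.graph.print_tabular()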

Comparison and execution

if "instance" in storage:
    if backward:
        print(f"-- {len(storage['instance'])} onnx models were creates")
        for i, inst in enumerate(storage["instance"]):
            print(f"  model {i}: {len(inst['inputs'])} runs")

        # deal with backward
        onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
        assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
        onnx_models = list(sorted([m for m in models if m.endswith(".onnx") and "_1" in m]))
        assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
        model_onnxrt = os.path.join(folder, onnx_models[1])
        model_debug = os.path.join(folder, onnx_models[0])
    else:
        onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
        if len(onnx_models) == 2:
            model_onnxrt = os.path.join(folder, onnx_models[1])
            model_debug = os.path.join(folder, onnx_models[0])
        else:
            model_debug = os.path.join(folder, onnx_models[0])
            # the following error may appear:
            # Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
            print(f"One model is missing, onnx_models={onnx_models}")
            model_onnxrt = model_debug

    print(f"model_onnxrt={model_onnxrt}")
    print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx

The inputs of both models

if "instance" in storage:
    print("onnxrt:", inputs_from_onnx_model(model_onnxrt))
    print("debug:", inputs_from_onnx_model(model_debug))
onnxrt: [('INPUT', 'primals_2', 1, (512, 512)), ('INPUT', 'primals_1', 1, (2, 1024, 512)), ('INPUT', 'primals_3', 1, (512, 512)), ('INPUT', 'primals_4', 1, (512, 512)), ('INPUT', 'primals_5', 1, (32,)), ('INPUT', 'primals_6', 7, (1, 1024)), ('INPUT', 'primals_7', 1, (2, 1, 1024, 1024)), ('INPUT', 'primals_8', 1, (512, 512))]
debug: [('INPUT', 'input0', 1, (2, 1024, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (32,)), ('INPUT', 'input5', 7, (1, 1024)), ('INPUT', 'input6', 1, (2, 1, 1024, 1024)), ('INPUT', 'input7', 1, (512, 512))]

Inputs are not the same. The first model has more of them, and some inputs were moved into the initializer list for model_debug.

if "instance" in storage:
    print("debug:", inputs_from_onnx_model(model_debug, init=True))
debug: [('INPUT', 'input0', 1, (2, 1024, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (32,)), ('INPUT', 'input5', 7, (1, 1024)), ('INPUT', 'input6', 1, (2, 1, 1024, 1024)), ('INPUT', 'input7', 1, (512, 512)), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_-1_64', 7, (4,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_2', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s2_0_2', 7, (2,)), ('INIT', 'init7_s2_32_32', 7, (2,))]

Optimization and verification

Let’s try the models with a python backend (reference implementation). First, onnxscript uses many local functions. The reference evaluator expects every function to be defined before it is used, so the order of functions in the model matters. This runtime does not allow recursion. We need to reorder them because the function Rank is usually placed at the end of the model.
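
reorder_functions_in_proto, imported above, takes care of this reordering; a sketch of the call, assuming it modifies the file in place:

if "instance" in storage:
    reorder_functions_in_proto(model_onnxrt)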

Let’s load the models and optimize them.

if "instance" in storage:
    debug = onnx.load(model_debug)
    try:
        onnxrt = optimize_model_proto_oxs(onnx.load(model_onnxrt))
    except ImportError as e:
        print("missing library", e)
        onnxrt = debug

Let’s apply onnxruntime optimization

if "instance" in storage and ortopt:
    providers = (
        [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
        if use_cuda
        else ["CPUExecutionProvider"]
    )
    with open(model_onnxrt.replace(".onnx", ".before.opt.onnx"), "wb") as f:
        f.write(onnxrt.SerializeToString())
    print(f"run onnxruntime optimization on {model_onnxrt}")
    optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
    ort_optimize(onnxrt, output=optimized, providers=providers)
    onnxrt = onnx.load(optimized)

    print(f"run onnxruntime optimization on {model_debug}")
    optimized = model_debug.replace(".onnx", ".opt.onnx")
    ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
    debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx
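
ort_optimize wraps onnxruntime; roughly the same effect can be obtained with onnxruntime's public API, which serializes the optimized graph when the session is created. A minimal sketch with hypothetical file names:

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

so = SessionOptions()
so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# onnxruntime writes the optimized model to this path when the session is built
so.optimized_model_filepath = "dump_models/llama_debug_0.opt.onnx"
InferenceSession(
    "dump_models/llama_debug_0.onnx", so, providers=["CPUExecutionProvider"]
)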

For what follows, we need to build matching inputs for both models.

if "instance" in storage:
    print("build_matching_inputs")
    feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
    print("done")
build_matching_inputs
done
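
The matched inputs can be inspected the same way the debug feeds were inspected above:

if "instance" in storage:
    for k, v in feedsrt.items():
        print(f"-- {k} {v.dtype} {v.shape}")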

We check that both models run.

if "instance" in storage:
    out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
    out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
    assert out_onnxrt
    assert out_debug

# assert_all_close(out_onnxrt, out_debug)

Side by side

In the table below, = means both models produced the same result, ~ means the results are aligned but differ, - marks a result that only appears in the first model (onnxrt), and + a result that only appears in the second one (debug).

if "instance" in storage:
    res1, res2, align, dc = compare_onnx_execution(
        onnxrt,
        debug,
        verbose=1,
        raise_exc=True,
        inputs=(feedsrt, feeds),
    )
    text = dc.to_str(res1, res2, align, column_size=90)
    print(text)
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 94 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 94 results (first model)
[compare_onnx_execution] got 82 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 105 pairs
[compare_onnx_execution] done
001 = | INITIA int64    1:2                  USAA                 _val_301                         | INITIA int64    1:2                  USAA                 init7_s2_2048_512
002 - | INITIA int64                         AAAA                 aten_unsqueeze_75_dim_0          |
003 ~ | INITIA int64    1:4                  CIKK                 _val_274                         | INITIA int64    1:3                  CKSA                 init7_s3_2_1024_512
004 - | INITIA int64                         BAAA                 aten_unsqueeze_311_dim_0         |
005 - | INITIA int64    1:3                  QMKA                 _val_269                         |
006 ~ | INITIA int64    1:3                  CKSA                 _val_96                          | INITIA int64    1:4                  CKZM                 init7_s4_2_1024_-1_64
007 - | INITIA int64                         CAAA                 aten_unsqueeze_159_dim_0         |
008 ~ | INITIA int64    1:3                  QKMA                 _val_264                         | INITIA int64    1:1                  BAAA                 init7_s1_1
009 - | INITIA int64    1:4                  CKZM                 _val_137                         |
010 ~ | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64    1:2                  ACAA                 init7_s2_0_2
011 - | INITIA float32                       IAAA                 _val_276                         |
012 ~ | INITIA int64    1:3                  CKZA                 _val_298                         | INITIA int64    1:3                  QKMA                 init7_s3_16_1024_64
013 ~ | INITIA int64    1:3                  QKKA                 _val_287                         | INITIA int64    1:3                  QMKA                 init7_s3_16_64_1024
014 = | INITIA int64    1:2                  GGAA                 splits_token_9                   | INITIA int64    1:2                  GGAA                 init7_s2_32_32
015 ~ | INITIA int64    1:4                  CIKM                 _val_293                         | INITIA int64    1:3                  QKKA                 init7_s3_16_1024_1024
016 + |                                                                                            | INPUT  float32  3:2x1024x512         PYUE                 input0
017 = | INPUT  float32  2:512x512            XZAG                 primals_2                        | INPUT  float32  2:512x512            XZAG                 input1
018 - | INPUT  float32  3:2x1024x512         PYUE                 primals_1                        |
019 = | INPUT  float32  2:512x512            PVAD                 primals_3                        | INPUT  float32  2:512x512            PVAD                 input2
020 = | INPUT  float32  2:512x512            GBBF                 primals_4                        | INPUT  float32  2:512x512            GBBF                 input3
021 = | INPUT  float32  1:32                 DAAA                 primals_5                        | INPUT  float32  1:32                 DAAA                 input4
022 = | INPUT  int64    2:1x1024             KAQG                 primals_6                        | INPUT  int64    2:1x1024             KAQG                 input5
023 = | INPUT  float32  4:2x1x1024x1024      AAAA                 primals_7                        | INPUT  float32  4:2x1x1024x1024      AAAA                 input6
024 = | INPUT  float32  2:512x512            CTSV                 primals_8                        | INPUT  float32  2:512x512            CTSV                 input7
025 - | RESULT float32  2:512x512            CTSV Identity        t_6                              |
026 = | RESULT int64    3:1x1x1024           KAQG Unsqueeze       unsqueeze_2                      | RESULT int64    3:1x1x1024           KAQG Unsqueeze       unsqueeze_2
027 = | RESULT float32  3:1x1x1024           KAQG Cast            _to_copy                         | RESULT float32  3:1x1x1024           KAQG Cast            _to_copy
028 - | RESULT float32  2:1x32               DAAA Unsqueeze       unsqueeze                        |
029 = | RESULT float32  3:1x32x1             DAAA Unsqueeze       unsqueeze_1                      | RESULT float32  3:1x32x1             DAAA Unsqueeze       unsqueeze_1
030 + |                                                                                            | RESULT float32  3:1x32x1024          EFXM MatMul          bmm
031 - | RESULT float32  3:1x1024x32          XCHM FusedMatMul     transpose_3                      |
032 - | RESULT float32  3:1x1024x64          VFPY Concat          cat                              |
033 ~ | RESULT float32  3:1x1024x64          GSEC Sin             sin                              | RESULT float32  3:1x64x1024          JKJK Concat          cat_token_5
034 ~ | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       unsqueeze_4                      | RESULT float32  3:1x64x1024          RMRM Sin             sin_token_7
035 + |                                                                                            | RESULT float32  4:1x1x64x1024        RMRM Unsqueeze       Opset8
036 = | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_4_out0           | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_10_out0
037 = | RESULT float32  2:2048x512           PYUE Reshape         view                             | RESULT float32  2:2048x512           PYUE Reshape         output_2
038 ~ | RESULT float32  2:2048x512           SUEW FusedMatMul     mm_1                             | RESULT float32  2:2048x512           SUEW Gemm            mm_1
039 - | RESULT float32  3:2x1024x512         SUEW Reshape         _unsafe_view_1                   |
040 = | RESULT float32  4:2x1024x8x64        SUEW Reshape         view_4                           | RESULT float32  4:2x1024x8x64        SUEW Reshape         view_4
041 = | RESULT float32  4:2x1024x8x32        UFUT Split           Slice_263                        | RESULT float32  4:2x1024x8x32        UFUT Split           SlicesSplitPattern--slice_Tensor
042 = | RESULT float32  4:2x1024x8x32        ZOKD Split           Slice_280                        | RESULT float32  4:2x1024x8x32        ZOKD Split           SlicesSplitPattern--slice_Tensor
043 = | RESULT float32  4:2x1024x8x32        BMQX Neg             Neg_290                          | RESULT float32  4:2x1024x8x32        BMQX Neg             neg2
044 = | RESULT float32  4:2x1024x8x64        URKR Concat          Concat_294                       | RESULT float32  4:2x1024x8x64        URKR Concat          cat3
045 = | RESULT float32  4:2x1024x8x64        XKLR Mul             Mul_315                          | RESULT float32  4:2x1024x8x64        XKLR Mul             mul_Tensor10
046 + |                                                                                            | RESULT float32  3:1x64x1024          NHNH Cos             cos_token_13
047 ~ | RESULT float32  3:1x1024x64          CJYF Cos             cos                              | RESULT float32  4:1x1x64x1024        NHNH Unsqueeze       Opset7
048 - | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       unsqueeze_3                      |
049 = | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_6_out0           | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_16_out0
050 = | RESULT float32  4:2x1024x8x64        QRHY Mul             Mul_313                          | RESULT float32  4:2x1024x8x64        QRHY Mul             mul_Tensor9
051 = | RESULT float32  4:2x1024x8x64        NBTQ Add             Add_317                          | RESULT float32  4:2x1024x8x64        NBTQ Add             add_Tensor2
052 = | RESULT float32  4:2x8x64x1024        MCDE Transpose       transpose_4                      | RESULT float32  4:2x8x64x1024        MCDE Transpose       transpose_4
053 + |                                                                                            | RESULT float32  4:1x1x1024x64        GSEC Transpose       output_5
054 - | RESULT float32  3:16x64x1024         MCDE Reshape         _unsafe_view_4                   |
055 ~ | RESULT float32  2:2048x512           MVCJ FusedMatMul     mm                               | RESULT float32  2:2048x512           PYUE Reshape         output_1
056 ~ | RESULT float32  3:2x1024x512         MVCJ Reshape         _unsafe_view                     | RESULT float32  2:2048x512           MVCJ Gemm            mm
057 = | RESULT float32  4:2x1024x8x64        MVCJ Reshape         view_3                           | RESULT float32  4:2x1024x8x64        MVCJ Reshape         view_3
058 = | RESULT float32  4:2x8x1024x64        RQST Transpose       transpose                        | RESULT float32  4:2x8x1024x64        RQST Transpose       transpose
059 = | RESULT float32  4:2x8x1024x32        KJGH Split           slice_4                          | RESULT float32  4:2x8x1024x32        KJGH Split           slice_4
060 = | RESULT float32  4:2x8x1024x32        HHML Split           slice_5                          | RESULT float32  4:2x8x1024x32        HHML Split           slice_5
061 = | RESULT float32  4:2x8x1024x32        TTOP Neg             neg                              | RESULT float32  4:2x8x1024x32        TTOP Neg             neg
062 = | RESULT float32  4:2x8x1024x64        CDUW Concat          cat_1                            | RESULT float32  4:2x8x1024x64        CDUW Concat          cat_1
063 = | RESULT float32  4:2x8x1024x64        NUIG Mul             mul_3                            | RESULT float32  4:2x8x1024x64        NUIG Mul             mul_3
064 + |                                                                                            | RESULT float32  4:1x1x1024x64        CJYF Transpose       output_4
065 = | RESULT float32  4:2x8x1024x64        OUMI Mul             mul_2                            | RESULT float32  4:2x8x1024x64        OUMI Mul             mul_2
066 = | RESULT float32  4:2x8x1024x64        BPUO Add             add                              | RESULT float32  4:2x8x1024x64        BPUO Add             add
067 - | RESULT float32  3:16x1024x64         BPUO Reshape         _unsafe_view_3                   |
068 - | RESULT float32  3:16x1024x1024       CTYN MatMul          bmm_1                            |
069 - | RESULT float32  4:2x8x1024x1024      CTYN Reshape         view_9                           |
070 ~ | RESULT float32  4:2x8x1024x1024      EFQC Div             div                              | RESULT float32  4:2x8x1024x1024      EFQC FusedMatMul     div
071 = | RESULT float32  4:2x8x1024x1024      EFQC Add             add_2                            | RESULT float32  4:2x8x1024x1024      EFQC Add             add_2
072 = | RESULT float32  4:2x8x1024x1024      NNNO Softmax         _softmax                         | RESULT float32  4:2x8x1024x1024      NNNO Softmax         output_8
073 - | RESULT float32  3:16x1024x1024       NNNO Reshape         view_10                          |
074 ~ | RESULT float32  2:2048x512           UUMW FusedMatMul     mm_2                             | RESULT float32  2:2048x512           PYUE Reshape         output_3
075 ~ | RESULT float32  3:2x1024x512         UUMW Reshape         _unsafe_view_2                   | RESULT float32  2:2048x512           UUMW Gemm            mm_2
076 = | RESULT float32  4:2x1024x8x64        UUMW Reshape         view_5                           | RESULT float32  4:2x1024x8x64        UUMW Reshape         view_5
077 = | RESULT float32  4:2x8x1024x64        KGKZ Transpose       transpose_2                      | RESULT float32  4:2x8x1024x64        KGKZ Transpose       transpose_2
078 ~ | RESULT float32  3:16x1024x64         KGKZ Reshape         _unsafe_view_5                   | RESULT float32  4:2x8x1024x64        GMGA MatMul          view_11
079 ~ | RESULT float32  3:16x1024x64         GMGA MatMul          bmm_2                            | RESULT float32  4:2x1024x8x64        IJJX Transpose       transpose_5
080 ~ | RESULT float32  4:2x8x1024x64        GMGA Reshape         view_11                          | RESULT float32  2:2048x512           IJJX Reshape         output_12
081 ~ | RESULT float32  4:2x1024x8x64        IJJX Transpose       transpose_5                      | RESULT float32  2:2048x512           GHZE Gemm            mm_3
082 ~ | RESULT float32  3:2x1024x512         IJJX Reshape         view_12                          | RESULT float32  3:2x1024x512         GHZE Reshape         output_0
083 + |                                                                                            | RESULT float32  3:16x1024x1024       NNNO Reshape         output_9
084 ~ | RESULT float32  2:2048x512           IJJX Reshape         view_13                          | RESULT float32  3:16x64x1024         MCDE Reshape         output_7
085 ~ | RESULT float32  2:2048x512           GHZE FusedMatMul     mm_3                             | RESULT float32  3:16x1024x64         BPUO Reshape         output_6
086 ~ | RESULT float32  3:2x1024x512         GHZE Reshape         _unsafe_view_6                   | RESULT float32  3:16x1024x64         KGKZ Reshape         output_10
087 + |                                                                                            | RESULT float32  2:512x512            RYSA Transpose       output_11
088 - | RESULT float32  3:16x1024x1024       NNNO Transpose       transpose_7                      |
089 - | RESULT float32  4:2x8x1024x1024      NNNO Identity        detach_3                         |
090 ~ | RESULT float32  3:16x1024x64         MCDE Transpose       transpose_10                     | OUTPUT float32  3:2x1024x512         GHZE                 output_0
091 ~ | RESULT float32  3:16x64x1024         BPUO Transpose       transpose_9                      | OUTPUT float32  2:2048x512           PYUE                 output_1
092 - | RESULT float32  3:16x64x1024         KGKZ Transpose       transpose_8                      |
093 = | OUTPUT float32  2:2048x512           PYUE                 view                             | OUTPUT float32  2:2048x512           PYUE                 output_2
094 - | OUTPUT float32  2:512x512            CTSV                 t_6                              |
095 ~ | OUTPUT float32  3:16x64x1024         KGKZ                 transpose_8                      | OUTPUT float32  2:2048x512           PYUE                 output_3
096 ~ | OUTPUT float32  3:1x1024x64          VFPY                 cat                              | OUTPUT float32  4:1x1x1024x64        CJYF                 output_4
097 + |                                                                                            | OUTPUT float32  4:1x1x1024x64        GSEC                 output_5
098 + |                                                                                            | OUTPUT float32  3:16x1024x64         BPUO                 output_6
099 ~ | OUTPUT float32  3:16x64x1024         BPUO                 transpose_9                      | OUTPUT float32  3:16x64x1024         MCDE                 output_7
100 - | OUTPUT float32  3:16x1024x64         MCDE                 transpose_10                     |
101 = | OUTPUT float32  4:2x8x1024x1024      NNNO                 detach_3                         | OUTPUT float32  4:2x8x1024x1024      NNNO                 output_8
102 = | OUTPUT float32  3:16x1024x1024       NNNO                 transpose_7                      | OUTPUT float32  3:16x1024x1024       NNNO                 output_9
103 ~ | OUTPUT float32  2:2048x512           IJJX                 view_13                          | OUTPUT float32  3:16x1024x64         KGKZ                 output_10
104 + |                                                                                            | OUTPUT float32  2:512x512            RYSA                 output_11
105 ~ | OUTPUT float32  3:2x1024x512         GHZE                 _unsafe_view_6                   | OUTPUT float32  2:2048x512           IJJX                 output_12

Total running time of the script: (0 minutes 6.666 seconds)

Related examples

101: A custom backend for torch

301: Compares LLAMA exporters

102: Fuse kernels in a small Llama Model

Gallery generated by Sphinx-Gallery