301: Compares LLAMA exporters for onnxrt backend

The script compares two ONNX exports of the same PyTorch model, one produced with the onnxrt backend and one with a debugging backend. It runs both exported models side by side and compares their execution.

To run the script:

python _doc/examples/plot_llama_diff_dort.py --help

The following example compares the forward step for mixed precision on cuda and produces all the intermediate onnx graphs.

python _doc/examples/plot_llama_diff_dort.py --part model --ortopt 1 \
        --cuda 1 --backward 0 --mixed 1

You may use --backward=1 to compare the backward graphs instead.

Some helpers

from experimental_experiment.args import get_parsed_args

script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("model", "one value among model, ..."),
    ortopt=(1, "run onnxruntime optimization"),
    backward=(0, "run backward (one training iteration) to compare backward graphs"),
    cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,cuda,mixed,opset",
)


import copy
import os
import warnings
import logging

try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
    ort_optimize,
    optimize_model_proto_oxs,
)
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.torch_models.dump_helper import (
    assert_all_close,
    dump_onnx,
    reorder_functions_in_proto,
    inputs_from_onnx_model,
    build_matching_inputs,
    results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
    train_loop,
    make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
    onnx_debug_backend,
    get_decomposition_table,
)

has_cuda = has_cuda and torch.cuda.device_count() > 0
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"

The exporting functions

print(f"part={script_args.part}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
backward = script_args.backward in (1, "1")
print(f"backward={backward}")
use_cuda = script_args.cuda in (1, "1")
print(f"cuda={use_cuda}")
use_mixed = script_args.mixed in (1, "1")
print(f"mixed={use_mixed}")
opset = int(script_args.opset)
print(f"opset={opset}")
part=model
ortopt=True
backward=False
cuda=False
mixed=False
opset=18

Model and data

if unit_test_going():
    kwargs = dict(input_dims=[(2, 1024)] * 2)
else:
    kwargs = dict(
        input_dims=[(2, 1024)] * 2,
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

if script_args.part == "model":
    model, inputs = get_llama_model(**kwargs)
else:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")

if use_cuda:
    model = model.to("cuda")
    inputs = [[i.to("cuda") for i in inp] for inp in inputs]

print(f"simple run with {len(inputs)} inputs")
if backward:
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = train_loop(copy.deepcopy(model), *inputs[0])
            torch.cuda.synchronize()
    else:
        expected = train_loop(copy.deepcopy(model), *inputs[0])
    print(
        f"-- eager mode worked, {len(expected)} gradients, first one is "
        f"{expected[0].shape}, {expected[0].dtype}"
    )
else:
    if use_mixed:
        assert use_cuda, "mixed precision only works with cuda"
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            expected = model(*inputs[0])
            torch.cuda.synchronize()
    else:
        expected = model(*inputs[0])
    print(results_to_string(expected))
simple run with 2 inputs
1 results
  torch.float32 (2, 1024, 512) [sum=1.39e+04]

Exporting

if hasattr(torch._dynamo.variables.misc, "LoggingLoggerVariable"):
    # A tweak to make torch.export.export work.
    torch._dynamo.variables.misc.LoggingLoggerVariable.call_method = lambda *_, **__: None


folder = "dump_models"
storage = {}

if backward:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)

    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
    )

    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = train_loop(optimized_mod, *inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = train_loop(optimized_mod, *inputs[0])
    assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
    print(
        f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
        f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
    )

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **kwargs,
        ),
        decompositions=get_decomposition_table(),
    )
    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)

    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            torch.cuda.synchronize()
            got = train_loop(onnx_mod, *inputs[0])
            torch.cuda.synchronize()
    else:
        got = train_loop(onnx_mod, *inputs[0])
    assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
    print(
        f"-- debug backend worked, {len(got)} gradients, first one is "
        f"{got[0].shape}, {got[0].dtype}"
    )

else:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = optimized_mod(*inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = optimized_mod(*inputs[0])
    assert_all_close(expected, expected_onnxrt, atol=1e-2)

    # debugging backend
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=17,
            storage=storage,
            **kwargs,
        )
    )

    onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            got = onnx_mod(*inputs[0])
    else:
        try:
            got = onnx_mod(*inputs[0])
        except Exception as e:
            print(f"ERROR: {e}")
            got = None
    if got is not None:
        assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/_exporter_legacy.py:109: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py:505: FutureWarning: 'onnxscript.values.TracedOnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  self.param_schema = self.onnxfunction.param_schemas()

For forward, each backend produces two files: one ONNX model and the graph module printed in a text file. For backward, each backend produces two ONNX models.

models = os.listdir(folder)
print(f"exported models: {models}")
exported models: ['llama_debug_0.onnx', 'llama_onnxrt_0.txt', 'llama_debug_0.txt', 'llama_onnxrt_0.onnx']

Inputs used by the debug backend

if "instance" in storage:
    feeds = storage["instance"][0]["inputs"][0]
    for k, v in feeds.items():
        print(f"-- {k} {v.dtype} {v.shape}")
-- input0 int64 (2, 1024)
-- input1 float32 (4000, 512)
-- input2 float32 (2, 1024)
-- input3 float32 (32,)
-- input4 float32 (512,)
-- input5 float32 (512, 512)
-- input6 float32 (512, 512)
-- input7 float32 (512, 512)
-- input8 float32 (512, 512)
-- input9 float32 (512,)
-- input10 float32 (2000, 512)
-- input11 float32 (2000, 512)
-- input12 float32 (512, 2000)
-- input13 float32 (512,)

Let’s print the first lines of the graph module.

if "instance" in storage:
    graph_module = storage["instance"][0]["graph_module"]
    print("\n".join(str(graph_module.graph).split("\n")[:10]))
graph():
    %primals_1 : [num_users=2] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=2] = placeholder[target=primals_5]
    %primals_6 : [num_users=1] = placeholder[target=primals_6]
    %primals_7 : [num_users=1] = placeholder[target=primals_7]
    %primals_8 : [num_users=1] = placeholder[target=primals_8]
    %primals_9 : [num_users=1] = placeholder[target=primals_9]
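
The graph module stored by the debug backend is a regular torch.fx.GraphModule, so its nodes can also be inspected programmatically. A minimal sketch, using the same storage dictionary as above:

if "instance" in storage:
    gm = storage["instance"][0]["graph_module"]
    for node in list(gm.graph.nodes)[:5]:
        # every fx node exposes its kind (placeholder, call_function, ...),
        # its target and the number of consumers
        print(node.op, node.target, len(node.users))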

Comparison and execution

if "instance" in storage:
    if backward:
        print(f"-- {len(storage['instance'])} onnx models were creates")
        for i, inst in enumerate(storage["instance"]):
            print(f"  model {i}: {len(inst['inputs'])} runs")

        # deal with backward
        onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
        assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
        onnx_models = list(sorted([m for m in models if m.endswith(".onnx") and "_1" in m]))
        assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
        model_onnxrt = os.path.join(folder, onnx_models[1])
        model_debug = os.path.join(folder, onnx_models[0])
    else:
        onnx_models = list(sorted([m for m in models if m.endswith(".onnx")]))
        if len(onnx_models) == 2:
            model_onnxrt = os.path.join(folder, onnx_models[1])
            model_debug = os.path.join(folder, onnx_models[0])
        else:
            model_debug = os.path.join(folder, onnx_models[0])
            # the following error may appear:
            # Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
            print(f"One model is missing, onnx_models={onnx_models}")
            model_onnxrt = model_debug

    print(f"model_onnxrt={model_onnxrt}")
    print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx

The inputs of both models

if "instance" in storage:
    print("onnxrt:", inputs_from_onnx_model(model_onnxrt))
    print("debug:", inputs_from_onnx_model(model_debug))
onnxrt: [('INPUT', 'primals_2', 1, (4000, 512)), ('INPUT', 'primals_1', 7, (2, 1024)), ('INPUT', 'primals_3', 1, (2, 1024)), ('INPUT', 'primals_4', 1, (32,)), ('INPUT', 'primals_6', 1, (512, 512)), ('INPUT', 'primals_7', 1, (512, 512)), ('INPUT', 'primals_8', 1, (512, 512)), ('INPUT', 'primals_9', 1, (512, 512)), ('INPUT', 'primals_11', 1, (2000, 512)), ('INPUT', 'primals_12', 1, (2000, 512)), ('INPUT', 'primals_13', 1, (512, 2000)), ('INPUT', 'primals_5', 1, (512,)), ('INPUT', 'primals_10', 1, (512,)), ('INPUT', 'primals_14', 1, (512,))]
debug: [('INPUT', 'input0', 7, (2, 1024)), ('INPUT', 'input1', 1, (4000, 512)), ('INPUT', 'input2', 1, (2, 1024)), ('INPUT', 'input3', 1, (32,)), ('INPUT', 'input4', 1, (512,)), ('INPUT', 'input5', 1, (512, 512)), ('INPUT', 'input6', 1, (512, 512)), ('INPUT', 'input7', 1, (512, 512)), ('INPUT', 'input8', 1, (512, 512)), ('INPUT', 'input9', 1, (512,)), ('INPUT', 'input10', 1, (2000, 512)), ('INPUT', 'input11', 1, (2000, 512)), ('INPUT', 'input12', 1, (512, 2000)), ('INPUT', 'input13', 1, (512,))]

Inputs are not the same. The first model has more inputs; for model_debug, some of them were moved into the initializer list.

if "instance" in storage:
    print("debug:", inputs_from_onnx_model(model_debug, init=True))
debug: [('INPUT', 'input0', 7, (2, 1024)), ('INPUT', 'input1', 1, (4000, 512)), ('INPUT', 'input2', 1, (2, 1024)), ('INPUT', 'input3', 1, (32,)), ('INPUT', 'input4', 1, (512,)), ('INPUT', 'input5', 1, (512, 512)), ('INPUT', 'input6', 1, (512, 512)), ('INPUT', 'input7', 1, (512, 512)), ('INPUT', 'input8', 1, (512, 512)), ('INPUT', 'input9', 1, (512,)), ('INPUT', 'input10', 1, (2000, 512)), ('INPUT', 'input11', 1, (2000, 512)), ('INPUT', 'input12', 1, (512, 2000)), ('INPUT', 'input13', 1, (512,)), ('INIT', 'init7_s_0', 7, ()), ('INIT', 'init7_s_1024', 7, ()), ('INIT', 'init7_s_1', 7, ()), ('INIT', 'init7_s2_1024_1024', 7, (2,)), ('INIT', 'init7_s2_-1_1', 7, (2,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init7_s4_2_1_1024_1024', 7, (4,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init1_s1_', 1, (1,)), ('INIT', 'init1_s_2', 1, ()), ('INIT', 'init1_s1_2', 1, (1,)), ('INIT', 'init1_s_3', 1, ()), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_-1_64', 7, (4,)), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_4', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s3_2_1024_2000', 7, (3,)), ('INIT', 'init7_s2_2048_2000', 7, (2,)), ('INIT', 'init7_s2_0_1', 7, (2,)), ('INIT', 'init7_s2_1_2', 7, (2,)), ('INIT', 'init7_s2_0_2', 7, (2,)), ('INIT', 'init7_s2_32_32', 7, (2,))]
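
The same kind of listing can be reproduced with the plain onnx API; a hedged sketch, independent of inputs_from_onnx_model:

import onnx

def list_inputs(path: str, init: bool = False):
    model = onnx.load(path)
    # graph inputs: name, element type (a TensorProto enum value), shape
    rows = [
        ("INPUT", i.name, i.type.tensor_type.elem_type,
         tuple(d.dim_value for d in i.type.tensor_type.shape.dim))
        for i in model.graph.input
    ]
    if init:
        # initializers are stored as TensorProto, their shape is in .dims
        rows += [("INIT", t.name, t.data_type, tuple(t.dims))
                 for t in model.graph.initializer]
    return rows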

Optimization and Verification

Let’s try the models with a Python backend (the reference implementation). First step: onnxscript relies on many functions. The reference evaluator expects every function to be defined before it is used, so the order of functions in the model matters, and this runtime allows no recursion. We need to reorder them, as the function Rank is usually placed at the end of the model.
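
A minimal sketch of such a reordering, assuming no recursion; this is only illustrative, reorder_functions_in_proto is the helper actually used by this example:

import onnx

def reorder_functions(model: onnx.ModelProto) -> None:
    funcs = list(model.functions)
    names = {(f.domain, f.name) for f in funcs}
    ordered, placed = [], set()
    while funcs:
        for f in funcs:
            # a function can be placed once all functions it calls are placed
            called = {(n.domain, n.op_type) for n in f.node} & names
            if called <= placed:
                ordered.append(f)
                placed.add((f.domain, f.name))
                funcs.remove(f)
                break
        else:
            raise RuntimeError("recursive functions cannot be ordered")
    del model.functions[:]
    model.functions.extend(ordered)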

Let’s load the models and optimize them.

if "instance" in storage:
    debug = onnx.load(model_debug)
    try:
        onnxrt = optimize_model_proto_oxs(onnx.load(model_onnxrt))
    except ImportError as e:
        print("missing library", e)
        onnxrt = debug
Applied 15 of general pattern rewrite rules.

Let’s apply onnxruntime optimization to both models.

if "instance" in storage and ortopt:
    providers = (
        [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
        if use_cuda
        else ["CPUExecutionProvider"]
    )
    with open(model_onnxrt.replace(".onnx", ".before.opt.onnx"), "wb") as f:
        f.write(onnxrt.SerializeToString())
    print(f"run onnxruntime optimization on {model_onnxrt}")
    optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
    ort_optimize(onnxrt, output=optimized, providers=providers)
    onnxrt = onnx.load(optimized)

    print(f"run onnxruntime optimization on {model_debug}")
    optimized = model_debug.replace(".onnx", ".opt.onnx")
    ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
    debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx
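
ort_optimize comes from this package. For reference, onnxruntime itself can serialize the graph it optimizes; a hedged sketch using only the public API:

import onnxruntime

def ort_dump_optimized(model_path, output_path, providers):
    so = onnxruntime.SessionOptions()
    so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # the optimized graph is written to this file when the session is created
    so.optimized_model_filepath = output_path
    onnxruntime.InferenceSession(model_path, so, providers=providers)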

For what follows, we need to build two sets of matching inputs.

if "instance" in storage:
    print("build_matching_inputs")
    feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
    print("done")
build_matching_inputs
done
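
build_matching_inputs pairs the inputs of both models. A minimal sketch of the idea, matching inputs by element type and shape and consuming duplicates in order (illustrative only, not the actual helper):

import onnx
from collections import defaultdict

def match_feeds(model_src, feeds, model_dst):
    def sig(i):
        t = i.type.tensor_type
        return t.elem_type, tuple(d.dim_value for d in t.shape.dim)
    # gather the source values under their (type, shape) signature
    pool = defaultdict(list)
    for i in onnx.load(model_src).graph.input:
        pool[sig(i)].append(feeds[i.name])
    # pick, for every destination input, a source value with the same signature
    return {i.name: pool[sig(i)].pop(0) for i in onnx.load(model_dst).graph.input}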

We check that both models run.

if "instance" in storage:
    out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
    out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
    assert out_onnxrt
    assert out_debug

# assert_all_close(out_onnxrt, out_debug)

Side by side

In the alignment below, = marks a result present and matching in both models, ~ a pair of aligned results that differ (name, operator, or shape), - a result only present in the first model (onnxrt), and + a result only present in the second (debug).

if "instance" in storage:
    res1, res2, align, dc = compare_onnx_execution(
        onnxrt,
        debug,
        verbose=1,
        raise_exc=True,
        inputs=(feedsrt, feeds),
    )
    text = dc.to_str(res1, res2, align, column_size=90)
    print(text)
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 177 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 177 results (first model)
[compare_onnx_execution] got 171 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 212 pairs
[compare_onnx_execution] done
001 ~ | INITIA float32                       ?AAA                 _val_22                          | INITIA float32  1:1                  AAAA                 _reshape_init1_s_0
002 - | INITIA float32  4:1024x1x2x1024      ????                 _val_415                         |
003 - | INITIA int64    2:1x1                AAAA                 _val_436                         |
004 ~ | INITIA int64    1:4                  CIKK                 _val_503                         | INITIA float32  1:1                  AAAA                 _reshape_init1_s_303
005 ~ | INITIA int64    1:3                  QMKA                 _val_464                         | INITIA float32  3:1x1x1024           KAQG                 _to_copy
006 - | INITIA int64                         BAAA                 dim_0__4                         |
007 ~ | INITIA int64    1:3                  CKSA                 _val_585                         | INITIA int64    1:3                  QKKA                 init7_s3_16_1024_1024
008 ~ | INITIA float32  4:1x2x1024x1024      ????                 _val_438                         | INITIA float32  4:2x1x1024x1024      ????                 expand_1
009 ~ | INITIA int64    1:3                  CKZA                 _val_544                         | INITIA int64    1:1                  BAAA                 init7_s1_1
010 - | INITIA int64    2:1024x1             KAQG                 _val_413                         |
011 ~ | INITIA float32                       AAAA                 scalar_tensor_default            | INITIA int64    1:2                  ACAA                 init7_s2_0_2
012 ~ | INITIA int64    1:2                  UYAA                 _val_581                         | INITIA int64    1:2                  GGAA                 init7_s2_32_32
013 ~ | INITIA int64                         AAAA                 dim_0__9                         | INITIA float32  1:1                  ?AAA                 init1_s1_
014 ~ | INITIA int64    1:1                  ZAAA                 _val_592                         | INITIA int64    1:2                  BCAA                 init7_s2_1_2
015 - | INITIA float32                       AAAA                 _val_522                         |
016 ~ | INITIA int64    1:2                  USAA                 _val_188                         | INITIA float32  1:1                  CAAA                 init1_s1_2
017 - | INITIA float32  3:1x1x1024           KAQG                 view_2                           |
018 - | INITIA float32  4:2x1x1024x1024      ????                 expand_1                         |
019 ~ | INITIA int64    1:3                  QKKA                 _val_533                         | INITIA int64    1:2                  UYAA                 init7_s2_2048_2000
020 ~ | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64    1:2                  USAA                 init7_s2_2048_512
021 - | INITIA int64                         CAAA                 _val_588                         |
022 - | INITIA float32                       AAAA                 _val_560                         |
023 ~ | INITIA int64    1:4                  CKZM                 _val_233                         | INITIA int64    1:3                  CKSA                 init7_s3_2_1024_512
024 ~ | INITIA int64    1:3                  CKYA                 _val_572                         | INITIA int64    1:4                  CKZM                 init7_s4_2_1024_-1_64
025 ~ | INITIA int64    1:2                  GGAA                 splits_token_16                  | INITIA int64    1:3                  QKMA                 init7_s3_16_1024_64
026 ~ | INITIA int64    1:4                  CIKM                 _val_539                         | INITIA int64    1:3                  QMKA                 init7_s3_16_64_1024
027 ~ | INITIA int64    1:3                  QKMA                 _val_377                         | INITIA int64    1:3                  CKYA                 init7_s3_2_1024_2000
028 + |                                                                                            | INPUT  int64    2:2x1024             ZHFJ                 input0
029 = | INPUT  float32  2:4000x512           AXYM                 primals_2                        | INPUT  float32  2:4000x512           AXYM                 input1
030 - | INPUT  int64    2:2x1024             ZHFJ                 primals_1                        |
031 = | INPUT  float32  2:2x1024             BACA                 primals_3                        | INPUT  float32  2:2x1024             BACA                 input2
032 = | INPUT  float32  1:32                 DAAA                 primals_4                        | INPUT  float32  1:32                 DAAA                 input3
033 + |                                                                                            | INPUT  float32  1:512                YYYY                 input4
034 = | INPUT  float32  2:512x512            EHWU                 primals_6                        | INPUT  float32  2:512x512            EHWU                 input5
035 = | INPUT  float32  2:512x512            WEYY                 primals_7                        | INPUT  float32  2:512x512            WEYY                 input6
036 = | INPUT  float32  2:512x512            CBSF                 primals_8                        | INPUT  float32  2:512x512            CBSF                 input7
037 = | INPUT  float32  2:512x512            JXFC                 primals_9                        | INPUT  float32  2:512x512            JXFC                 input8
038 + |                                                                                            | INPUT  float32  1:512                YYYY                 input9
039 = | INPUT  float32  2:2000x512           KOCG                 primals_11                       | INPUT  float32  2:2000x512           KOCG                 input10
040 = | INPUT  float32  2:2000x512           DKHP                 primals_12                       | INPUT  float32  2:2000x512           DKHP                 input11
041 = | INPUT  float32  2:512x2000           STEQ                 primals_13                       | INPUT  float32  2:512x2000           STEQ                 input12
042 = | INPUT  float32  1:512                YYYY                 primals_5                        | INPUT  float32  1:512                YYYY                 input13
043 ~ | INPUT  float32  1:512                YYYY                 primals_10                       | RESULT float32  1:512                YYYY Identity        output_4
044 ~ | INPUT  float32  1:512                YYYY                 primals_14                       | RESULT float32  1:512                YYYY Identity        output_3
045 - | RESULT float32  2:512x512            EHWU Identity        t_33                             |
046 - | RESULT float32  2:512x512            WEYY Identity        t_29                             |
047 - | RESULT float32  2:512x512            CBSF Identity        t_25                             |
048 - | RESULT float32  2:512x512            JXFC Identity        t_21                             |
049 - | RESULT float32  2:2000x512           KOCG Identity        t_17                             |
050 - | RESULT float32  2:2000x512           DKHP Identity        t_13                             |
051 - | RESULT float32  2:512x2000           STEQ Identity        t_9                              |
052 ~ | RESULT float32  2:1x32               DAAA Unsqueeze       unsqueeze_7                      | RESULT float32  1:512                YYYY Identity        output_2
053 + |                                                                                            | RESULT int64    2:2x1024             ZHFJ Identity        output_1
054 = | RESULT float32  3:1x32x1             DAAA Unsqueeze       unsqueeze_8                      | RESULT float32  3:1x32x1             DAAA Unsqueeze       unsqueeze_8
055 = | RESULT float32  3:1x32x1024          EFXM MatMul          view_3                           | RESULT float32  3:1x32x1024          EFXM MatMul          bmm
056 = | RESULT float32  3:1x64x1024          JKJK Concat          Concat_392                       | RESULT float32  3:1x64x1024          JKJK Concat          cat_token_5
057 ~ | RESULT float32  3:1x1024x64          VFPY Transpose       cat                              | RESULT float32  3:1x64x1024          RMRM Sin             sin_token_7
058 ~ | RESULT float32  3:1x1024x64          GSEC Sin             sin                              | RESULT float32  4:1x1x64x1024        RMRM Unsqueeze       unsqueeze10
059 - | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       unsqueeze_11                     |
060 = | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_5_out0           | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_10_out0
061 = | RESULT float32  3:2x1024x512         BFOY Gather          embedding                        | RESULT float32  3:2x1024x512         BFOY Gather          output_5
062 = | RESULT float32  3:2x1024x512         BAAA Pow             pow_1                            | RESULT float32  3:2x1024x512         BAAA Pow             pow_1
063 = | RESULT float32  3:2x1024x1           AAAA ReduceMean      mean                             | RESULT float32  3:2x1024x1           AAAA ReduceMean      mean
064 = | RESULT float32  3:2x1024x1           AAAA Add             add_1                            | RESULT float32  3:2x1024x1           AAAA Add             add_1
065 = | RESULT float32  3:2x1024x1           KKKK Sqrt            _val_139                         | RESULT float32  3:2x1024x1           KKKK Sqrt            _onx_sqrt_add_10
066 = | RESULT float32  3:2x1024x1           UTBY Reciprocal      rsqrt                            | RESULT float32  3:2x1024x1           UTBY Reciprocal      output_6
067 = | RESULT float32  3:2x1024x512         NBBI Mul             mul_3                            | RESULT float32  3:2x1024x512         NBBI Mul             output_7
068 = | RESULT float32  3:2x1024x512         NBBI Mul             mul_4                            | RESULT float32  3:2x1024x512         NBBI Mul             mul_4
069 = | RESULT float32  2:2048x512           NBBI Reshape         view_4                           | RESULT float32  2:2048x512           NBBI Reshape         output_9
070 ~ | RESULT float32  2:2048x512           MTOA FusedMatMul     mm_1                             | RESULT float32  2:2048x512           MTOA Gemm            mm_1
071 - | RESULT float32  3:2x1024x512         MTOA Reshape         _unsafe_view_1                   |
072 = | RESULT float32  4:2x1024x8x64        MTOA Reshape         view_7                           | RESULT float32  4:2x1024x8x64        MTOA Reshape         view_7
073 = | RESULT float32  4:2x1024x8x32        IDGJ Split           Slice_460                        | RESULT float32  4:2x1024x8x32        IDGJ Split           SlicesSplitPattern--slice_Tensor
074 = | RESULT float32  4:2x1024x8x32        FPIR Split           Slice_477                        | RESULT float32  4:2x1024x8x32        FPIR Split           SlicesSplitPattern--slice_Tensor
075 = | RESULT float32  4:2x1024x8x32        VLSJ Neg             Neg_500                          | RESULT float32  4:2x1024x8x32        VLSJ Neg             neg2
076 = | RESULT float32  4:2x1024x8x64        EOXS Concat          Concat_508                       | RESULT float32  4:2x1024x8x64        EOXS Concat          cat3
077 = | RESULT float32  4:2x1024x8x64        TOQN Mul             Mul_521                          | RESULT float32  4:2x1024x8x64        TOQN Mul             mul_Tensor15
078 + |                                                                                            | RESULT float32  3:1x64x1024          NHNH Cos             cos_token_13
079 ~ | RESULT float32  3:1x1024x64          CJYF Cos             cos                              | RESULT float32  4:1x1x64x1024        NHNH Unsqueeze       unsqueeze9
080 - | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       unsqueeze_10                     |
081 = | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_7_out0           | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_16_out0
082 = | RESULT float32  4:2x1024x8x64        VBXL Mul             Mul_519                          | RESULT float32  4:2x1024x8x64        VBXL Mul             mul_Tensor14
083 = | RESULT float32  4:2x1024x8x64        OPNZ Add             Add_526                          | RESULT float32  4:2x1024x8x64        OPNZ Add             add_Tensor4
084 = | RESULT float32  4:2x8x64x1024        WHMZ Transpose       transpose_4                      | RESULT float32  4:2x8x64x1024        WHMZ Transpose       transpose_4
085 + |                                                                                            | RESULT float32  4:1x1x1024x64        GSEC Transpose       output_15
086 - | RESULT float32  3:16x64x1024         WHMZ Reshape         _unsafe_view_4                   |
087 ~ | RESULT float32  2:2048x512           YZMV FusedMatMul     mm                               | RESULT float32  2:2048x512           YZMV Gemm            mm
088 - | RESULT float32  3:2x1024x512         YZMV Reshape         _unsafe_view                     |
089 = | RESULT float32  4:2x1024x8x64        YZMV Reshape         view_5                           | RESULT float32  4:2x1024x8x64        YZMV Reshape         view_5
090 = | RESULT float32  4:2x8x1024x64        OIYJ Transpose       transpose_1                      | RESULT float32  4:2x8x1024x64        OIYJ Transpose       transpose_1
091 = | RESULT float32  4:2x8x1024x32        CQTL Split           slice_24                         | RESULT float32  4:2x8x1024x32        CQTL Split           slice_24
092 = | RESULT float32  4:2x8x1024x32        NTGY Split           slice_25                         | RESULT float32  4:2x8x1024x32        NTGY Split           slice_25
093 = | RESULT float32  4:2x8x1024x32        NHUC Neg             neg                              | RESULT float32  4:2x8x1024x32        NHUC Neg             neg
094 = | RESULT float32  4:2x8x1024x64        PXNO Concat          cat_1                            | RESULT float32  4:2x8x1024x64        PXNO Concat          cat_1
095 = | RESULT float32  4:2x8x1024x64        SQNQ Mul             mul_6                            | RESULT float32  4:2x8x1024x64        SQNQ Mul             mul_6
096 + |                                                                                            | RESULT float32  4:1x1x1024x64        CJYF Transpose       output_14
097 = | RESULT float32  4:2x8x1024x64        SVSM Mul             mul_5                            | RESULT float32  4:2x8x1024x64        SVSM Mul             mul_5
098 = | RESULT float32  4:2x8x1024x64        KKGC Add             add_2                            | RESULT float32  4:2x8x1024x64        KKGC Add             add_2
099 - | RESULT float32  3:16x1024x64         KKGC Reshape         _unsafe_view_3                   |
100 - | RESULT float32  3:16x1024x1024       UMRA MatMul          bmm_1                            |
101 - | RESULT float32  4:2x8x1024x1024      UMRA Reshape         view_10                          |
102 ~ | RESULT float32  4:2x8x1024x1024      SMYR Mul             mul_9                            | RESULT float32  4:2x8x1024x1024      SMYR FusedMatMul     _onx_mul_view_100
103 - | RESULT float32  3:2x1x1024           BACA Unsqueeze       unsqueeze_5                      |
104 = | RESULT float32  4:2x1x1x1024         BACA Unsqueeze       unsqueeze_6                      | RESULT float32  4:2x1x1x1024         BACA Unsqueeze       unsqueeze_6
105 = | RESULT float32  4:2x1x1024x1024      ???? Add             add                              | RESULT float32  4:2x1x1024x1024      ???? Add             add
106 = | RESULT bool     4:2x1x1024x1024      KWTE Equal           eq                               | RESULT bool     4:2x1x1024x1024      KWTE Equal           eq
107 = | RESULT float32  4:2x1x1024x1024      ???? Where           masked_fill                      | RESULT float32  4:2x1x1024x1024      ???? Where           masked_fill
108 - | RESULT float32  4:1024x1x2x1024      ???? Transpose       _val_414                         |
109 - | RESULT float32  4:1024x1x2x1024      ???? ScatterND       _val_416                         |
110 - | RESULT float32  4:1x2x1024x1024      ???? Transpose       _val_437                         |
111 - | RESULT float32  4:1x2x1024x1024      ???? ScatterND       _val_439                         |
112 - | RESULT float32  4:2x1x1024x1024      ???? Transpose       slice_scatter_1                  |
113 = | RESULT float32  4:2x8x1024x1024      ???? Add             add_4                            | RESULT float32  4:2x8x1024x1024      ???? Add             add_4
114 = | RESULT float32  4:2x8x1024x1024      OOON Softmax         _softmax                         | RESULT float32  4:2x8x1024x1024      OOON Softmax         output_18
115 - | RESULT float32  3:16x1024x1024       OOON Reshape         view_11                          |
116 ~ | RESULT float32  2:2048x512           PTHN FusedMatMul     mm_2                             | RESULT float32  2:2048x512           PTHN Gemm            mm_2
117 - | RESULT float32  3:2x1024x512         PTHN Reshape         _unsafe_view_2                   |
118 = | RESULT float32  4:2x1024x8x64        PTHN Reshape         view_9                           | RESULT float32  4:2x1024x8x64        PTHN Reshape         view_9
119 = | RESULT float32  4:2x8x1024x64        AHLI Transpose       transpose_3                      | RESULT float32  4:2x8x1024x64        AHLI Transpose       transpose_3
120 - | RESULT float32  3:16x1024x64         AHLI Reshape         _unsafe_view_5                   |
121 - | RESULT float32  3:16x1024x64         EFMR MatMul          bmm_2                            |
122 ~ | RESULT float32  4:2x8x1024x64        EFMR Reshape         view_12                          | RESULT float32  4:2x8x1024x64        EFMR MatMul          view_12
123 = | RESULT float32  4:2x1024x8x64        EEPN Transpose       transpose_5                      | RESULT float32  4:2x1024x8x64        EEPN Transpose       transpose_5
124 - | RESULT float32  3:2x1024x512         EEPN Reshape         view_13                          |
125 = | RESULT float32  2:2048x512           EEPN Reshape         view_14                          | RESULT float32  2:2048x512           EEPN Reshape         output_22
126 ~ | RESULT float32  2:2048x512           RRNI FusedMatMul     mm_3                             | RESULT float32  2:2048x512           RRNI Gemm            mm_3
127 = | RESULT float32  3:2x1024x512         RRNI Reshape         _unsafe_view_6                   | RESULT float32  3:2x1024x512         RRNI Reshape         _unsafe_view_6
128 = | RESULT float32  3:2x1024x512         SWCG Add             add_5                            | RESULT float32  3:2x1024x512         SWCG Add             output_23
129 = | RESULT float32  3:2x1024x512         DYPC Pow             pow_2                            | RESULT float32  3:2x1024x512         DYPC Pow             pow_2
130 = | RESULT float32  3:2x1024x1           VVLL ReduceMean      mean_1                           | RESULT float32  3:2x1024x1           VVLL ReduceMean      mean_1
131 = | RESULT float32  3:2x1024x1           VVLL Add             add_6                            | RESULT float32  3:2x1024x1           VVLL Add             add_6
132 = | RESULT float32  3:2x1024x1           BBZZ Sqrt            _val_562                         | RESULT float32  3:2x1024x1           BBZZ Sqrt            _onx_sqrt_add_60
133 = | RESULT float32  3:2x1024x1           GGHJ Reciprocal      rsqrt_1                          | RESULT float32  3:2x1024x1           GGHJ Reciprocal      output_24
134 = | RESULT float32  3:2x1024x512         SIYF Mul             mul_10                           | RESULT float32  3:2x1024x512         SIYF Mul             output_25
135 = | RESULT float32  3:2x1024x512         SIYF Mul             mul_11                           | RESULT float32  3:2x1024x512         SIYF Mul             mul_11
136 = | RESULT float32  2:2048x512           SIYF Reshape         view_15                          | RESULT float32  2:2048x512           SIYF Reshape         output_27
137 ~ | RESULT float32  2:2048x2000          VALD FusedMatMul     mm_4                             | RESULT float32  2:2048x2000          VALD Gemm            mm_4
138 = | RESULT float32  3:2x1024x2000        VALD Reshape         _unsafe_view_7                   | RESULT float32  3:2x1024x2000        VALD Reshape         output_28
139 ~ | RESULT float32  3:2x1024x2000        ZDZH QuickGelu       silu                             | RESULT float32  3:2x1024x2000        DEHS Sigmoid         _onx_sigmoid__unsafe_view_70
140 ~ | RESULT float32  2:2048x2000          EBAV FusedMatMul     mm_5                             | RESULT float32  2:2048x2000          DEHS Reshape         Reshape2Of3PatternR__onx_sigmoid
141 ~ | RESULT float32  3:2x1024x2000        EBAV Reshape         _unsafe_view_8                   | RESULT float32  2:2048x2000          ZDZH Mul             Reshape2Of3PatternL_output_29
142 ~ | RESULT float32  3:2x1024x2000        EWBL Mul             mul_12                           | RESULT float32  2:2048x2000          EBAV Gemm            mm_5
143 ~ | RESULT float32  2:2048x2000          EWBL Reshape         view_17                          | RESULT float32  2:2048x2000          EWBL Mul             output_34
144 ~ | RESULT float32  2:2048x512           SCST FusedMatMul     mm_6                             | RESULT float32  2:2048x512           SCST Gemm            mm_6
145 = | RESULT float32  3:2x1024x512         SCST Reshape         _unsafe_view_9                   | RESULT float32  3:2x1024x512         SCST Reshape         _unsafe_view_9
146 = | RESULT float32  3:2x1024x512         KXUY Add             add_7                            | RESULT float32  3:2x1024x512         KXUY Add             output_35
147 = | RESULT float32  3:2x1024x512         OOJQ Pow             pow_3                            | RESULT float32  3:2x1024x512         OOJQ Pow             pow_3
148 = | RESULT float32  3:2x1024x1           BBQQ ReduceMean      mean_2                           | RESULT float32  3:2x1024x1           BBQQ ReduceMean      mean_2
149 = | RESULT float32  3:2x1024x1           BBQQ Add             add_8                            | RESULT float32  3:2x1024x1           BBQQ Add             add_8
150 = | RESULT float32  3:2x1024x1           OONN Sqrt            _val_596                         | RESULT float32  3:2x1024x1           OONN Sqrt            _onx_sqrt_add_80
151 = | RESULT float32  3:2x1024x1           HHEH Reciprocal      rsqrt_2                          | RESULT float32  3:2x1024x1           HHEH Reciprocal      output_36
152 = | RESULT float32  3:2x1024x512         IPHM Mul             mul_13                           | RESULT float32  3:2x1024x512         IPHM Mul             output_37
153 = | RESULT float32  3:2x1024x512         IPHM Mul             mul_14                           | RESULT float32  3:2x1024x512         IPHM Mul             output_0
154 + |                                                                                            | RESULT float32  3:2x1024x2000        EBAV Reshape         output_32
155 + |                                                                                            | RESULT float32  3:2x1024x2000        ZDZH Reshape         output_29
156 + |                                                                                            | RESULT float32  2:2048x512           SIYF Identity        output_31
157 ~ | RESULT float32  3:16x1024x1024       OOON Transpose       transpose_7                      | RESULT float32  3:16x1024x1024       OOON Reshape         output_19
158 + |                                                                                            | RESULT float32  3:16x64x1024         WHMZ Reshape         output_17
159 + |                                                                                            | RESULT float32  3:16x1024x64         KKGC Reshape         output_16
160 - | RESULT float32  4:2x8x1024x1024      OOON Identity        detach_13                        |
161 ~ | RESULT float32  3:16x1024x64         WHMZ Transpose       transpose_10                     | RESULT float32  3:16x1024x64         AHLI Reshape         output_20
162 ~ | RESULT float32  3:16x64x1024         KKGC Transpose       transpose_9                      | RESULT float32  2:2048x512           NBBI Identity        output_11
163 ~ | RESULT float32  3:16x64x1024         AHLI Transpose       transpose_8                      | RESULT float32  2:2048x512           NBBI Identity        output_13
164 - | OUTPUT float32  3:2x1024x512         BFOY                 embedding                        |
165 ~ | OUTPUT float32  2:512x512            EHWU                 t_33                             | RESULT float32  2:512x512            ZVHA Transpose       output_8
166 ~ | OUTPUT float32  2:512x512            WEYY                 t_29                             | RESULT float32  2:512x512            BYCS Transpose       output_10
167 ~ | OUTPUT float32  2:512x512            CBSF                 t_25                             | RESULT float32  2:512x512            CEYX Transpose       output_12
168 ~ | OUTPUT float32  2:512x512            JXFC                 t_21                             | RESULT float32  2:512x512            FDBD Transpose       output_21
169 + |                                                                                            | RESULT float32  2:512x2000           CUNX Transpose       output_26
170 + |                                                                                            | RESULT float32  2:512x2000           ADXJ Transpose       output_30
171 ~ | OUTPUT float32  2:2000x512           KOCG                 t_17                             | RESULT float32  2:2000x512           UYAO Transpose       output_33
172 + |                                                                                            | OUTPUT float32  3:2x1024x512         IPHM                 output_0
173 + |                                                                                            | OUTPUT int64    2:2x1024             ZHFJ                 output_1
174 + |                                                                                            | OUTPUT float32  1:512                YYYY                 output_2
175 + |                                                                                            | OUTPUT float32  1:512                YYYY                 output_3
176 + |                                                                                            | OUTPUT float32  1:512                YYYY                 output_4
177 + |                                                                                            | OUTPUT float32  3:2x1024x512         BFOY                 output_5
178 - | OUTPUT float32  2:2000x512           DKHP                 t_13                             |
179 - | OUTPUT float32  2:512x2000           STEQ                 t_9                              |
180 = | OUTPUT float32  3:2x1024x1           UTBY                 rsqrt                            | OUTPUT float32  3:2x1024x1           UTBY                 output_6
181 + |                                                                                            | OUTPUT float32  3:2x1024x512         NBBI                 output_7
182 + |                                                                                            | OUTPUT float32  2:512x512            ZVHA                 output_8
183 = | OUTPUT float32  2:2048x512           NBBI                 view_4                           | OUTPUT float32  2:2048x512           NBBI                 output_9
184 + |                                                                                            | OUTPUT float32  2:512x512            BYCS                 output_10
185 - | OUTPUT float32  3:1x1024x64          VFPY                 cat                              |
186 ~ | OUTPUT float32  3:16x64x1024         AHLI                 transpose_8                      | OUTPUT float32  2:2048x512           NBBI                 output_11
187 + |                                                                                            | OUTPUT float32  2:512x512            CEYX                 output_12
188 ~ | OUTPUT float32  3:16x64x1024         KKGC                 transpose_9                      | OUTPUT float32  2:2048x512           NBBI                 output_13
189 + |                                                                                            | OUTPUT float32  4:1x1x1024x64        CJYF                 output_14
190 + |                                                                                            | OUTPUT float32  4:1x1x1024x64        GSEC                 output_15
191 ~ | OUTPUT float32  3:16x1024x64         WHMZ                 transpose_10                     | OUTPUT float32  3:16x1024x64         KKGC                 output_16
192 + |                                                                                            | OUTPUT float32  3:16x64x1024         WHMZ                 output_17
193 = | OUTPUT float32  4:2x8x1024x1024      OOON                 detach_13                        | OUTPUT float32  4:2x8x1024x1024      OOON                 output_18
194 = | OUTPUT float32  3:16x1024x1024       OOON                 transpose_7                      | OUTPUT float32  3:16x1024x1024       OOON                 output_19
195 + |                                                                                            | OUTPUT float32  3:16x1024x64         AHLI                 output_20
196 + |                                                                                            | OUTPUT float32  2:512x512            FDBD                 output_21
197 = | OUTPUT float32  2:2048x512           EEPN                 view_14                          | OUTPUT float32  2:2048x512           EEPN                 output_22
198 ~ | OUTPUT float32  2:2048x512           RRNI                 mm_3                             | OUTPUT float32  3:2x1024x512         SWCG                 output_23
199 = | OUTPUT float32  3:2x1024x1           GGHJ                 rsqrt_1                          | OUTPUT float32  3:2x1024x1           GGHJ                 output_24
200 + |                                                                                            | OUTPUT float32  3:2x1024x512         SIYF                 output_25
201 + |                                                                                            | OUTPUT float32  2:512x2000           CUNX                 output_26
202 = | OUTPUT float32  2:2048x512           SIYF                 view_15                          | OUTPUT float32  2:2048x512           SIYF                 output_27
203 ~ | OUTPUT float32  2:2048x2000          VALD                 mm_4                             | OUTPUT float32  3:2x1024x2000        VALD                 output_28
204 + |                                                                                            | OUTPUT float32  3:2x1024x2000        ZDZH                 output_29
205 + |                                                                                            | OUTPUT float32  2:512x2000           ADXJ                 output_30
206 + |                                                                                            | OUTPUT float32  2:2048x512           SIYF                 output_31
207 ~ | OUTPUT float32  2:2048x2000          EBAV                 mm_5                             | OUTPUT float32  3:2x1024x2000        EBAV                 output_32
208 + |                                                                                            | OUTPUT float32  2:2000x512           UYAO                 output_33
209 = | OUTPUT float32  2:2048x2000          EWBL                 view_17                          | OUTPUT float32  2:2048x2000          EWBL                 output_34
210 = | OUTPUT float32  3:2x1024x512         KXUY                 add_7                            | OUTPUT float32  3:2x1024x512         KXUY                 output_35
211 = | OUTPUT float32  3:2x1024x1           HHEH                 rsqrt_2                          | OUTPUT float32  3:2x1024x1           HHEH                 output_36
212 = | OUTPUT float32  3:2x1024x512         IPHM                 mul_14                           | OUTPUT float32  3:2x1024x512         IPHM                 output_37

Total running time of the script: (0 minutes 43.282 seconds)

Related examples

101: A custom backend for torch

301: Compares LLAMA exporters

102: Fuse kernels in a small Llama Model

201: Evaluate DORT

201: Evaluate different ways to export a torch model to ONNX