301: Compares LLAMA exporters for onnxrt backend

The script compares exported models in pytorch using onnxrt backend. It tries to do a side by side of the execution of both models.

To run the script:

python _doc/examples/plot_llama_diff_dort.py --help

The following example compares the forward step for mixed precision on cuda and produces all the intermediate onnx graphs.

python _doc/examples/plot_llama_diff_dort.py --part model --ortopt 1 --cuda 1 --backward 0 --mixed 1

You may use --backward=1 to compare the backward graphs.

Some helpers

from experimental_experiment.args import get_parsed_args

# Command-line options of the script. Every keyword below declares one option
# as a ``(default value, help text)`` pair.
# NOTE(review): ``expose`` names ``exporter``, which is not declared here, and
# does not list ``backward`` -- confirm this is intended.
script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    ortopt=(1, "run onnxruntime optimization"),
    backward=(0, "does one operator for backward"),
    cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,cuda,mixed,opset",
)


import copy
import os
import warnings
import logging

try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
    optimize_model_proto,
    ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
    get_llama_model,
    get_llama_attention,
    get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import (
    assert_all_close,
    dump_onnx,
    reorder_functions_in_proto,
    inputs_from_onnx_model,
    build_matching_inputs,
    results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
    train_loop,
    make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
    onnx_debug_backend,
    get_decomposition_table,
)

# CUDA counts as available only when onnxruntime exposes its CUDA provider
# and torch can reach a device as well.
if has_cuda:
    has_cuda = torch.cuda.is_available()
# Silence every log record of severity ERROR and below.
logging.disable(logging.ERROR)
provider = "cuda" if has_cuda else "cpu"

The exporting functions

def _is_enabled(value):
    """Tell whether a parsed option holds ``1`` or ``"1"``."""
    return value in (1, "1")


# Turn the parsed options into plain python values and echo each of them.
print(f"part={script_args.part}")
ortopt = _is_enabled(script_args.ortopt)
print(f"ortopt={ortopt}")
backward = _is_enabled(script_args.backward)
print(f"backward={backward}")
use_cuda = _is_enabled(script_args.cuda)
print(f"cuda={use_cuda}")
use_mixed = _is_enabled(script_args.mixed)
print(f"mixed={use_mixed}")
opset = int(script_args.opset)
print(f"opset={opset}")
part=attention
ortopt=True
backward=False
cuda=False
mixed=False
opset=18

Model and data

# Arguments given to the model builder: only the input dimensions while unit
# tests are running, an explicit model configuration otherwise.
kwargs = dict(input_dims=[(2, 1024)] * 2)
if not unit_test_going():
    kwargs.update(
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

# One builder per supported value of ``--part``.
_model_builders = {
    "attention": get_llama_attention,
    "decoder": get_llama_decoder,
    "model": get_llama_model,
}
if script_args.part not in _model_builders:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
model, inputs = _model_builders[script_args.part](**kwargs)

if use_cuda:
    model = model.to("cuda")
    inputs = [[tensor.to("cuda") for tensor in sample] for sample in inputs]


def _run_eager():
    """Run the eager model on the first set of inputs.

    The training loop runs on a copy of the model when ``backward`` is set,
    otherwise the model is simply called.
    """
    if backward:
        return train_loop(copy.deepcopy(model), *inputs[0])
    return model(*inputs[0])


print(f"simple run with {len(inputs)} inputs")
if use_mixed:
    assert use_cuda, "mixed precision only works with cuda"
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        torch.cuda.synchronize()
        expected = _run_eager()
        torch.cuda.synchronize()
else:
    expected = _run_eager()

if backward:
    print(
        f"-- eager mode worked, {len(expected)} gradients, first one is "
        f"{expected[0].shape}, {expected[0].dtype}"
    )
else:
    print(results_to_string(expected))
simple run with 2 inputs
torch.float32 (2, 1024, 512) [sum=595]

Exporting

# Folder receiving the dumped models, and the dictionary handed to the debug
# backend through its ``storage`` argument (read back later in the script).
folder = "dump_models"
storage = {}


def _make_debug_compiler(**aot_options):
    """Build an AOT-autograd backend wrapping ``onnx_debug_backend``.

    The forward compiler dumps its results under ``folder`` with the prefix
    ``llama_debug``, targets the opset requested on the command line and
    receives ``storage``. Keyword arguments are forwarded to ``aot_autograd``.
    """

    def _fw_compiler(*args, **options):
        return onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **options,
        )

    return aot_autograd(fw_compiler=_fw_compiler, **aot_options)


if backward:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)

    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
    )

    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = train_loop(optimized_mod, *inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = train_loop(optimized_mod, *inputs[0])
    assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
    print(
        f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
        f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
    )

    # debugging backend
    aot_compiler = _make_debug_compiler(decompositions=get_decomposition_table())
    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)

    # NOTE(review): the autocast variant of this run was switched off in the
    # original script (``if False and use_mixed``), so the debug backend always
    # trains without autocast and only the tolerance depends on ``use_mixed``.
    # Confirm before enabling autocast here.
    got = train_loop(onnx_mod, *inputs[0])
    assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
    print(
        f"-- debug backend worked, {len(got)} gradients, first one is "
        f"{got[0].shape}, {got[0].dtype}"
    )

else:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
    optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        if use_mixed:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                torch.cuda.synchronize()
                expected_onnxrt = optimized_mod(*inputs[0])
                torch.cuda.synchronize()
        else:
            expected_onnxrt = optimized_mod(*inputs[0])
    assert_all_close(expected, expected_onnxrt, atol=1e-2)

    # debugging backend: it now honours ``--opset`` like the backward branch
    # does (the opset used to be hard-coded to 17 here).
    aot_compiler = _make_debug_compiler()

    onnx_mod = torch.compile(model, backend=aot_compiler, fullgraph=True)
    if use_mixed:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            got = onnx_mod(*inputs[0])
    else:
        got = onnx_mod(*inputs[0])
    assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
Applied 0 pattern rewrite rules.
Applied 0 pattern rewrite rules.

For forward, there are two files, one onnx model and the graph module printed in a txt file. For backward, there are two onnx models. Then it is multiplied by the number of backends.

# os.listdir returns entries in arbitrary order; sorting keeps the printed
# listing identical from one run to the next.
models = sorted(os.listdir(folder))
print(f"exported models: {models}")
exported models: ['llama_onnxrt_0.onnx', 'llama_debug_0.onnx', 'llama_debug_0.txt', 'llama_onnxrt_0.txt']

Inputs used by the debug backend

# Inputs recorded for the first run of the first model stored by the debug
# backend, displayed one per line with name, dtype and shape.
first_instance = storage["instance"][0]
feeds = first_instance["inputs"][0]
for name, value in feeds.items():
    print(f"-- {name} {value.dtype} {value.shape}")
-- input0 float32 (512, 512)
-- input1 float32 (512, 512)
-- input2 float32 (512, 512)
-- input3 float32 (512, 512)
-- input4 float32 (2048, 64)
-- input5 float32 (2048, 64)
-- input6 float32 (2, 1024, 512)
-- input7 int64 (1, 1024)
-- input8 float32 (2, 1, 1024, 1024)

Let’s see the first lines of the graph module.

# Graph module stored with the first model; only the first ten lines of its
# textual graph are displayed.
graph_module = storage["instance"][0]["graph_module"]
graph_lines = str(graph_module.graph).split("\n")
print("\n".join(graph_lines[:10]))
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=1] = placeholder[target=primals_5]
    %primals_6 : [num_users=1] = placeholder[target=primals_6]
    %primals_7 : [num_users=3] = placeholder[target=primals_7]
    %primals_8 : [num_users=2] = placeholder[target=primals_8]
    %primals_9 : [num_users=1] = placeholder[target=primals_9]

Comparison and execution

# Pick the two onnx files to compare. File names are sorted, so a
# ``llama_debug*`` file always comes before its ``llama_onnxrt*`` counterpart.
if backward:
    print(f"-- {len(storage['instance'])} onnx models were created")
    for i, inst in enumerate(storage["instance"]):
        print(f"  model {i}: {len(inst['inputs'])} runs")

    # deal with backward: four onnx files are expected, and the two whose
    # name contains "_1" are compared (presumably the backward graphs).
    onnx_models = sorted(m for m in models if m.endswith(".onnx"))
    assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
    onnx_models = [m for m in onnx_models if "_1" in m]
    assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
    model_onnxrt = os.path.join(folder, onnx_models[1])
    model_debug = os.path.join(folder, onnx_models[0])
else:
    onnx_models = sorted(m for m in models if m.endswith(".onnx"))
    if not onnx_models:
        # Fail with an explicit message instead of a bare IndexError below.
        raise RuntimeError(f"No onnx model found in {folder!r}, models={models}")
    if len(onnx_models) == 2:
        model_onnxrt = os.path.join(folder, onnx_models[1])
        model_debug = os.path.join(folder, onnx_models[0])
    else:
        model_debug = os.path.join(folder, onnx_models[0])
        # the following error may appear:
        # Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
        print(f"One model is missing, onnx_models={onnx_models}")
        # Fall back on comparing the first model with itself.
        model_onnxrt = model_debug

print(f"model_onnxrt={model_onnxrt}")
print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx

The inputs of both models

onnxrt: [('INPUT', 'primals_4', 1, (512, 512)), ('INPUT', 'primals_1', 1, (512, 512)), ('INPUT', 'primals_7', 1, (2, 1024, 512)), ('INPUT', 'primals_2', 1, (512, 512)), ('INPUT', 'primals_3', 1, (512, 512)), ('INPUT', 'primals_5', 1, (2048, 64)), ('INPUT', 'primals_6', 1, (2048, 64)), ('INPUT', 'primals_8', 7, (1, 1024)), ('INPUT', 'primals_9', 1, (2, 1, 1024, 1024))]
debug: [('INPUT', 'input0', 1, (512, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (2048, 64)), ('INPUT', 'input5', 1, (2048, 64)), ('INPUT', 'input6', 1, (2, 1024, 512)), ('INPUT', 'input7', 7, (1, 1024)), ('INPUT', 'input8', 1, (2, 1, 1024, 1024))]

Inputs are not the same. The first model has more inputs; some of them were moved into the list of initializers in model_debug.

# Same listing for the debug model, initializers included this time.
debug_inputs = inputs_from_onnx_model(model_debug, init=True)
print("debug:", debug_inputs)
debug: [('INPUT', 'input0', 1, (512, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (2048, 64)), ('INPUT', 'input5', 1, (2048, 64)), ('INPUT', 'input6', 1, (2, 1024, 512)), ('INPUT', 'input7', 7, (1, 1024)), ('INPUT', 'input8', 1, (2, 1, 1024, 1024)), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_8_64', 7, (4,)), ('INIT', 'init7_s1_0', 7, (1,)), ('INIT', 'init7_s1_1024', 7, (1,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s2_32_32', 7, (2,))]

Optimization and Verification

Let’s try the models with a python backend (reference implementation). First, onnx-script uses many functions. The reference evaluator expects every function to be defined before it is used, so the order of the functions in the model matters. Recursion is not allowed by this runtime. We need to reorder the functions, as function Rank is usually placed at the end of the model.

'dump_models/llama_onnxrt_0.onnx'

Let’s load the models and optimize them.

# The debug model is loaded as is; the onnxrt model goes through
# optimize_model_proto. If that step fails on a missing library, the debug
# model is used for both sides.
debug = onnx.load(model_debug)
try:
    raw_onnxrt = onnx.load(model_onnxrt)
    onnxrt = optimize_model_proto(raw_onnxrt)
except ImportError as exc:
    print("missing library", exc)
    onnxrt = debug
Applied 0 pattern rewrite rules.

Let’s apply onnxruntime optimization

if ortopt:
    # Execution providers handed to onnxruntime: CUDA first when requested,
    # CPU only otherwise.
    if use_cuda:
        providers = [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
    else:
        providers = ["CPUExecutionProvider"]

    # Keep a copy of the onnxrt model as it is before onnxruntime rewrites it.
    before_opt = model_onnxrt.replace(".onnx", ".before.opt.onnx")
    with open(before_opt, "wb") as f:
        f.write(onnxrt.SerializeToString())

    print(f"run onnxruntime optimization on {model_onnxrt}")
    optimized = model_onnxrt.replace(".onnx", ".opt.onnx")
    ort_optimize(onnxrt, output=optimized, providers=providers)
    onnxrt = onnx.load(optimized)

    # Same step for the debug model, with ``disable_aot=True``.
    print(f"run onnxruntime optimization on {model_debug}")
    optimized = model_debug.replace(".onnx", ".opt.onnx")
    ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
    debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx

For what follows, we need to build two lists of matching inputs.

# Build the feeds of the onnxrt model from the feeds recorded for the debug
# model. Both models name their inputs differently (``primals_*`` versus
# ``input*``); how the matching is done is up to build_matching_inputs
# (presumably on type and shape -- TODO confirm).
print("build_matching_inputs")
feedsrt = build_matching_inputs(model_debug, feeds, model_onnxrt)
print("done")
build_matching_inputs
done

We check both models are running.

# Run each model once with the reference evaluator and make sure both
# return something.
evaluator_onnxrt = ExtendedReferenceEvaluator(onnxrt)
out_onnxrt = evaluator_onnxrt.run(None, feedsrt)
evaluator_debug = ExtendedReferenceEvaluator(debug)
out_debug = evaluator_debug.run(None, feeds)
assert out_onnxrt
assert out_debug

# NOTE(review): the direct comparison of both outputs is left disabled.
# assert_all_close(out_onnxrt, out_debug)

Side by side

# Execute both models and align their intermediate results, then render the
# alignment as two columns of text.
comparison = compare_onnx_execution(
    onnxrt,
    debug,
    verbose=1,
    raise_exc=True,
    inputs=(feedsrt, feeds),
)
res1, res2, align, dc = comparison
text = dc.to_str(res1, res2, align, column_size=90)
print(text)
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 103 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 79 results
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 108 pairs
[compare_onnx_execution] done
001 = | INITIA int64    1:2                  USAA                 ortshared_7_1_2_0_token_175      | INITIA int64    1:2                  USAA                 ortshared_7_1_2_0_token_99
002 - | INITIA int64    1:4                  CIKM                 ortshared_7_1_4_1_token_171      |
003 - | INITIA int64    1:1                  KAAA                 ortshared_7_1_1_5_token_180      |
004 ~ | INITIA int64    1:3                  QKMA                 ortshared_7_1_3_0_token_167      | INITIA int64    1:3                  CKSA                 ortshared_7_1_3_0_token_98
005 ~ | INITIA int64    1:3                  QMKA                 ortshared_7_1_3_1_token_168      | INITIA int64    1:4                  CKIM                 ortshared_7_1_4_0_token_100
006 = | INITIA int64    1:1                  AAAA                 ortshared_7_1_1_3_token_169      | INITIA int64    1:1                  AAAA                 ortshared_7_1_1_0_token_97
007 ~ | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64    1:1                  KAAA                 ortshared_7_1_1_2_token_106
008 = | INITIA int64    1:1                  BAAA                 ortshared_7_1_1_1_token_163      | INITIA int64    1:1                  BAAA                 ortshared_7_1_1_1_token_105
009 - | INITIA float32                       IAAA                 ortshared_1_0_1_1_token_177      |
010 ~ | INITIA int64    1:2                  GGAA                 splits_token_181                 | INITIA int64    1:3                  QKMA                 ortshared_7_1_3_1_token_102
011 - | INITIA int64                         ZAAA                 ortshared_7_0_1_1_token_176      |
012 - | INITIA int64    1:4                  CIKK                 ortshared_7_1_4_2_token_174      |
013 - | INITIA float32                       BAAA                 ortshared_1_0_1_0_token_172      |
014 ~ | INITIA int64    1:3                  QKKA                 ortshared_7_1_3_3_token_178      | INITIA int64    1:3                  QMKA                 ortshared_7_1_3_3_token_107
015 - | INITIA int64                         BAAA                 ortshared_7_0_1_0_token_164      |
016 - | INITIA int64    1:4                  CKIM                 ortshared_7_1_4_0_token_165      |
017 ~ | INITIA int64    1:2                  BKAA                 ortshared_7_1_2_1_token_179      | INITIA int64    1:2                  GGAA                 ortshared_7_1_2_1_token_101
018 ~ | INITIA int64    1:3                  CKSA                 ortshared_7_1_3_2_token_173      | INITIA int64    1:3                  QKKA                 ortshared_7_1_3_2_token_103
019 = | INPUT  float32  2:512x512            UCQB                 primals_4                        | INPUT  float32  2:512x512            UCQB                 input0
020 = | INPUT  float32  2:512x512            URYC                 primals_1                        | INPUT  float32  2:512x512            URYC                 input1
021 - | INPUT  float32  3:2x1024x512         YWBT                 primals_7                        |
022 = | INPUT  float32  2:512x512            VBXD                 primals_2                        | INPUT  float32  2:512x512            VBXD                 input2
023 = | INPUT  float32  2:512x512            AUCY                 primals_3                        | INPUT  float32  2:512x512            AUCY                 input3
024 = | INPUT  float32  2:2048x64            MDRB                 primals_5                        | INPUT  float32  2:2048x64            MDRB                 input4
025 = | INPUT  float32  2:2048x64            ZHDU                 primals_6                        | INPUT  float32  2:2048x64            ZHDU                 input5
026 + |                                                                                            | INPUT  float32  3:2x1024x512         YWBT                 input6
027 = | INPUT  int64    2:1x1024             KAQG                 primals_8                        | INPUT  int64    2:1x1024             KAQG                 input7
028 = | INPUT  float32  4:2x1x1024x1024      AAAA                 primals_9                        | INPUT  float32  4:2x1x1024x1024      AAAA                 input8
029 - | RESULT float32  2:512x512            UCQB Identity        t_6                              |
030 - | RESULT float32  4:2x1x1024x1024      AAAA Mul             _inlfunc_aten_add|folded_2_other |
031 - | RESULT int64    2:1x1024             KAQG Expand          _val_65                          |
032 - | RESULT int64    3:1x1024x1           KAQG Unsqueeze       _val_67                          |
033 - | RESULT int64    3:1x1024x1           KAQG Concat          _val_68                          |
034 = | RESULT float32  2:1024x64            GSEC Slice           slice_2                          | RESULT float32  2:1024x64            GSEC Slice           slice_2
035 - | RESULT float32  2:1024x64            GSEC Transpose       _val_62                          |
036 ~ | RESULT float32  3:1x1024x64          GSEC GatherND        _val_69                          | RESULT float32  3:1x1024x64          GSEC Gather          index_1
037 = | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       aten_unsqueeze_116_n2            | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       output_5
038 = | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_5_out0           | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_4_out0
039 = | RESULT float32  2:2048x512           YWBT Reshape         view                             | RESULT float32  2:2048x512           YWBT Reshape         output_2
040 ~ | RESULT float32  2:2048x512           FVMX FusedMatMul     mm_1                             | RESULT float32  2:2048x512           XOOY Gemm            mm_1
041 - | RESULT float32  3:2x1024x512         FVMX Reshape         view_3                           |
042 ~ | RESULT float32  4:2x1024x8x64        FVMX Reshape         view_7                           | RESULT float32  4:2x1024x8x64        XOOY Reshape         view_7
043 ~ | RESULT float32  4:2x1024x8x32        AMAV Split           Slice_178                        | RESULT float32  4:2x1024x8x32        KQNE Split           SlicesSplitPattern--slice_Tensor
044 ~ | RESULT float32  4:2x1024x8x32        GKNC Split           Slice_195                        | RESULT float32  4:2x1024x8x32        NZBV Split           SlicesSplitPattern--slice_Tensor
045 ~ | RESULT float32  4:2x1024x8x32        UQNY Neg             aten_neg_199_n0                  | RESULT float32  4:2x1024x8x32        NBZF Neg             neg2
046 ~ | RESULT float32  4:2x1024x8x64        VCMT Concat          aten_cat_204_n0                  | RESULT float32  4:2x1024x8x64        XRMK Concat          cat2
047 ~ | RESULT float32  4:2x1024x8x64        NKQX Mul             aten_mul_208_n0                  | RESULT float32  4:2x1024x8x64        PBCM Mul             mul4
048 = | RESULT float32  2:1024x64            CJYF Slice           slice_1                          | RESULT float32  2:1024x64            CJYF Slice           slice_1
049 - | RESULT float32  2:1024x64            CJYF Transpose       _val_53                          |
050 ~ | RESULT float32  3:1x1024x64          CJYF GatherND        _val_60                          | RESULT float32  3:1x1024x64          CJYF Gather          index
051 = | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       aten_unsqueeze_115_n2            | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       output_4
052 = | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_8_out0           | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_6_out0
053 ~ | RESULT float32  4:2x1024x8x64        ALUG Mul             aten_mul_161_n0                  | RESULT float32  4:2x1024x8x64        KMQI Mul             mul3
054 ~ | RESULT float32  4:2x1024x8x64        MVKD Add             _inlfunc_aten_add|folded_1_n3    | RESULT float32  4:2x1024x8x64        YOSU Add             add_Tensor2
055 ~ | RESULT float32  4:2x8x64x1024        FCAN Transpose       transpose_3                      | RESULT float32  4:2x8x64x1024        KCYO Transpose       transpose_3
056 - | RESULT float32  3:16x64x1024         FCAN Reshape         view_10                          |
057 - | RESULT float32  4:1x1x1024x64        GSEC Transpose       unsqueeze_1                      |
058 ~ | RESULT float32  2:2048x512           XOOY FusedMatMul     mm                               | RESULT float32  2:2048x512           YWBT Reshape         output_1
059 ~ | RESULT float32  3:2x1024x512         XOOY Reshape         view_1                           | RESULT float32  2:2048x512           LECY Gemm            mm
060 ~ | RESULT float32  4:2x1024x8x64        XOOY Reshape         view_6                           | RESULT float32  4:2x1024x8x64        LECY Reshape         view_6
061 ~ | RESULT float32  4:2x8x1024x64        DJQW Transpose       transpose                        | RESULT float32  4:2x8x1024x64        JGZB Transpose       transpose
062 ~ | RESULT float32  4:2x8x1024x32        ZBVW Split           slice_3                          | RESULT float32  4:2x8x1024x32        VEUG Split           slice_3
063 ~ | RESULT float32  4:2x8x1024x32        DIVA Split           slice_4                          | RESULT float32  4:2x8x1024x32        NCGV Split           slice_4
064 ~ | RESULT float32  4:2x8x1024x32        XSFA Neg             neg                              | RESULT float32  4:2x8x1024x32        NYUF Neg             neg
065 ~ | RESULT float32  4:2x8x1024x64        VSBW Concat          cat                              | RESULT float32  4:2x8x1024x64        IDPL Concat          cat
066 ~ | RESULT float32  4:2x8x1024x64        OCLC Mul             mul_1                            | RESULT float32  4:2x8x1024x64        YARS Mul             mul_1
067 - | RESULT float32  4:1x1x1024x64        CJYF Transpose       unsqueeze                        |
068 ~ | RESULT float32  4:2x8x1024x64        VANM Mul             mul                              | RESULT float32  4:2x8x1024x64        OTPY Mul             mul
069 ~ | RESULT float32  4:2x8x1024x64        KCYO Add             add                              | RESULT float32  4:2x8x1024x64        NSFP Add             add
070 - | RESULT float32  3:16x1024x64         KCYO Reshape         view_9                           |
071 - | RESULT float32  3:16x1024x1024       MSON MatMul          bmm                              |
072 ~ | RESULT float32  4:2x8x1024x1024      MSON Reshape         view_11                          | RESULT float32  4:2x8x1024x1024      QPIE FusedMatMul     div
073 - | RESULT float32  4:2x8x1024x1024      YGCR Div             div                              |
074 ~ | RESULT float32  4:2x8x1024x1024      YGCR Add             add_2                            | RESULT float32  4:2x8x1024x1024      QPIE Add             add_2
075 ~ | RESULT float32  4:2x8x1024x1024      ONNN Softmax         _softmax                         | RESULT float32  4:2x8x1024x1024      ONNO Softmax         output_8
076 - | RESULT float32  3:16x1024x1024       ONNN Reshape         view_12                          |
077 ~ | RESULT float32  2:2048x512           MQUP FusedMatMul     mm_2                             | RESULT float32  2:2048x512           YWBT Reshape         output_3
078 ~ | RESULT float32  3:2x1024x512         MQUP Reshape         view_5                           | RESULT float32  2:2048x512           FVMX Gemm            mm_2
079 ~ | RESULT float32  4:2x1024x8x64        MQUP Reshape         view_8                           | RESULT float32  4:2x1024x8x64        FVMX Reshape         view_8
080 ~ | RESULT float32  4:2x8x1024x64        IUHD Transpose       transpose_2                      | RESULT float32  4:2x8x1024x64        ZCMY Transpose       transpose_2
081 ~ | RESULT float32  3:16x1024x64         IUHD Reshape         view_13                          | RESULT float32  4:2x8x1024x64        VPTA MatMul          view_11
082 ~ | RESULT float32  3:16x1024x64         IUNZ MatMul          bmm_1                            | RESULT float32  4:2x1024x8x64        FFHL Transpose       transpose_4
083 ~ | RESULT float32  4:2x8x1024x64        IUNZ Reshape         view_14                          | RESULT float32  2:2048x512           FFHL Reshape         output_12
084 ~ | RESULT float32  4:2x1024x8x64        SKKC Transpose       transpose_4                      | RESULT float32  2:2048x512           GDEI Gemm            mm_3
085 ~ | RESULT float32  3:2x1024x512         SKKC Reshape         view_15                          | RESULT float32  3:2x1024x512         GDEI Reshape         output_0
086 + |                                                                                            | RESULT float32  2:512x512            CXYY Transpose       output_11
087 ~ | RESULT float32  2:2048x512           SKKC Reshape         view_16                          | RESULT float32  3:16x1024x64         ZCMY Reshape         output_10
088 - | RESULT float32  2:2048x512           FJWU FusedMatMul     mm_3                             |
089 - | RESULT float32  3:2x1024x512         FJWU Reshape         view_17                          |
090 ~ | RESULT float32  3:16x1024x1024       ONNN Transpose       transpose_6                      | RESULT float32  3:16x1024x1024       ONNO Reshape         output_9
091 + |                                                                                            | RESULT float32  3:16x64x1024         KCYO Reshape         output_7
092 - | RESULT float32  4:2x8x1024x1024      ONNN Identity        detach_3                         |
093 ~ | RESULT float32  3:16x1024x64         FCAN Transpose       transpose_9                      | RESULT float32  3:16x1024x64         NSFP Reshape         output_6
094 + |                                                                                            | OUTPUT float32  3:2x1024x512         GDEI                 output_0
095 ~ | RESULT float32  3:16x64x1024         KCYO Transpose       transpose_8                      | OUTPUT float32  2:2048x512           YWBT                 output_1
096 ~ | RESULT float32  3:16x64x1024         IUHD Transpose       transpose_7                      | OUTPUT float32  2:2048x512           YWBT                 output_2
097 = | OUTPUT float32  2:2048x512           YWBT                 view                             | OUTPUT float32  2:2048x512           YWBT                 output_3
098 - | OUTPUT float32  2:512x512            UCQB                 t_6                              |
099 = | OUTPUT float32  4:1x1x1024x64        CJYF                 unsqueeze                        | OUTPUT float32  4:1x1x1024x64        CJYF                 output_4
100 = | OUTPUT float32  4:1x1x1024x64        GSEC                 unsqueeze_1                      | OUTPUT float32  4:1x1x1024x64        GSEC                 output_5
101 ~ | OUTPUT float32  3:16x64x1024         IUHD                 transpose_7                      | OUTPUT float32  3:16x1024x64         NSFP                 output_6
102 = | OUTPUT float32  3:16x64x1024         KCYO                 transpose_8                      | OUTPUT float32  3:16x64x1024         KCYO                 output_7
103 - | OUTPUT float32  3:16x1024x64         FCAN                 transpose_9                      |
104 ~ | OUTPUT float32  4:2x8x1024x1024      ONNN                 detach_3                         | OUTPUT float32  4:2x8x1024x1024      ONNO                 output_8
105 ~ | OUTPUT float32  3:16x1024x1024       ONNN                 transpose_6                      | OUTPUT float32  3:16x1024x1024       ONNO                 output_9
106 ~ | OUTPUT float32  2:2048x512           SKKC                 view_16                          | OUTPUT float32  3:16x1024x64         ZCMY                 output_10
107 + |                                                                                            | OUTPUT float32  2:512x512            CXYY                 output_11
108 ~ | OUTPUT float32  3:2x1024x512         FJWU                 view_17                          | OUTPUT float32  2:2048x512           FFHL                 output_12

Total running time of the script: (0 minutes 6.988 seconds)

Gallery generated by Sphinx-Gallery