301: Compares LLAMA exporters for onnxrt backend

The script compares the models exported by the onnxrt backend and by a debugging backend in PyTorch. It runs both models and displays their executions side by side.

To run the script:

python _doc/examples/plot_llama_diff_dort.py --help

The following example compares the forward step for mixed precision on cuda and produces all the intermediate onnx graphs.

python _doc/examples/plot_llama_diff_dort.py --part model --ortopt 1 \
        --cuda 1 --backward 0 --mixed 1

You may use --backward=1 to compare the backward graphs.

Some helpers

from experimental_experiment.args import get_parsed_args

# Options of this example.
# NOTE(review): ``expose`` is assumed to select the options accepted on the
# command line; it lists ``backward`` because the usage example in the module
# docstring passes ``--backward`` (``exporter`` is not an option of this script).
# NOTE(review): the program name below differs from this file's name
# (plot_llama_diff_dort) — confirm whether that is intended.
script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("attention", "one value among attention, decoder, model"),
    ortopt=(1, "run onnxruntime optimization"),
    backward=(0, "does one operator for backward"),
    cuda=(0, "use cuda or not"),
    mixed=(0, "use mixed precision"),
    opset=(18, "onnx opset"),
    expose="part,ortopt,backward,cuda,mixed,opset",
)


import copy
import os
import warnings
import logging

# Probe for onnxruntime: this example cannot run without it, so the script
# prints a message and exits with status 0 when the package is missing.
# Warnings emitted while importing onnxruntime are silenced.
try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        # True when this onnxruntime build provides the CUDA execution provider
        # (torch's own CUDA availability is combined with it further down).
        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import onnx
from onnx_array_api.reference import compare_onnx_execution, ExtendedReferenceEvaluator
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.convert.convert_helper import (
    optimize_model_proto_oxs,
    ort_optimize,
)
from experimental_experiment.torch_models.llama_helper import (
    get_llama_model,
    get_llama_attention,
    get_llama_decoder,
)
from experimental_experiment.torch_models.dump_helper import (
    assert_all_close,
    dump_onnx,
    reorder_functions_in_proto,
    inputs_from_onnx_model,
    build_matching_inputs,
    results_to_string,
)
from experimental_experiment.torch_models.training_helper import (
    train_loop,
    make_aot_ort,
)
from experimental_experiment.torch_dynamo import (
    onnx_debug_backend,
    get_decomposition_table,
)

# CUDA counts as available only when onnxruntime ships the CUDA provider
# *and* torch can see a GPU.
if has_cuda:
    has_cuda = torch.cuda.is_available()
# Silence logging calls of severity ERROR and below.
logging.disable(logging.ERROR)
if has_cuda:
    provider = "cuda"
else:
    provider = "cpu"

The exporting functions

# Flags may arrive as ints (defaults) or strings (command line), so both 1 and
# "1" switch them on. Every resolved value is echoed so that the rendered page
# records the configuration that produced it.
print("part=" + str(script_args.part))
ortopt = script_args.ortopt in (1, "1")
print("ortopt=" + str(ortopt))
backward = script_args.backward in (1, "1")
print("backward=" + str(backward))
use_cuda = script_args.cuda in (1, "1")
print("cuda=" + str(use_cuda))
use_mixed = script_args.mixed in (1, "1")
print("mixed=" + str(use_mixed))
opset = int(script_args.opset)
print("opset=" + str(opset))
part=attention
ortopt=True
backward=False
cuda=False
mixed=False
opset=18

Model and data

# Input shapes are always fixed; outside unit tests a small explicit LLaMA
# configuration is added, inside unit tests the helpers' defaults are kept.
kwargs = dict(input_dims=[(2, 1024)] * 2)
if not unit_test_going():
    kwargs.update(
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

# --part selects which helper builds the model and its inputs.
part_builders = {
    "attention": get_llama_attention,
    "decoder": get_llama_decoder,
    "model": get_llama_model,
}
if script_args.part not in part_builders:
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
model, inputs = part_builders[script_args.part](**kwargs)

if use_cuda:
    model = model.to("cuda")
    inputs = [[tensor.to("cuda") for tensor in sample] for sample in inputs]

print(f"simple run with {len(inputs)} inputs")

# Reference results computed in eager mode: gradients from train_loop when
# --backward is set, plain forward outputs otherwise.
if use_mixed:
    assert use_cuda, "mixed precision only works with cuda"
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        torch.cuda.synchronize()
        if backward:
            expected = train_loop(copy.deepcopy(model), *inputs[0])
        else:
            expected = model(*inputs[0])
        torch.cuda.synchronize()
elif backward:
    expected = train_loop(copy.deepcopy(model), *inputs[0])
else:
    expected = model(*inputs[0])

if not backward:
    print(results_to_string(expected))
else:
    print(
        f"-- eager mode worked, {len(expected)} gradients, first one is "
        f"{expected[0].shape}, {expected[0].dtype}"
    )
simple run with 2 inputs
torch.float32 (2, 1024, 512) [sum=-250]

Exporting

# Workaround on private torch internals: when this torch version defines
# torch._dynamo.variables.misc.LoggingLoggerVariable, its call_method is
# replaced by a no-op returning None (presumably so that logging calls met
# while tracing are ignored — confirm against the torch version in use).
if hasattr(torch._dynamo.variables.misc, "LoggingLoggerVariable"):
    # A tweak to make torch.export.export work.
    torch._dynamo.variables.misc.LoggingLoggerVariable.call_method = lambda *_, **__: None


folder = "dump_models"
# Filled by onnx_debug_backend; read back later through storage["instance"].
storage = {}


def _run_maybe_mixed(fn, *args):
    """Return ``fn(*args)``.

    When --mixed is set, the call runs under float16 CUDA autocast and is
    surrounded by ``torch.cuda.synchronize()`` calls.
    """
    if not use_mixed:
        return fn(*args)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        torch.cuda.synchronize()
        result = fn(*args)
        torch.cuda.synchronize()
    return result


def _make_debug_compiler(**aot_kwargs):
    """Return an aot_autograd backend whose forward compiler is onnx_debug_backend.

    The models are dumped with prefix ``<folder>/llama_debug`` using the opset
    given by --opset, and the backend records what it compiled into ``storage``.
    Extra keyword arguments (e.g. ``decompositions``) are passed to aot_autograd.
    """
    return aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
            *args,
            dump_prefix=os.path.join(folder, "llama_debug"),
            target_opset=opset,
            storage=storage,
            **kwargs,
        ),
        **aot_kwargs,
    )


if backward:
    # onnxrt backend
    local_aot_ort, _ = make_aot_ort(dynamic=False, rewrite=True)
    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, dynamic=False, fullgraph=True
    )
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        expected_onnxrt = _run_maybe_mixed(train_loop, optimized_mod, *inputs[0])
    assert_all_close(expected[0], expected_onnxrt[0], atol=1e-3)
    print(
        f"-- onnxrt backend worked, {len(expected_onnxrt)} gradients, first one is "
        f"{expected_onnxrt[0].shape}, {expected_onnxrt[0].dtype}"
    )

    # debugging backend
    aot_compiler = _make_debug_compiler(decompositions=get_decomposition_table())
    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)
    got = _run_maybe_mixed(train_loop, onnx_mod, *inputs[0])
    assert_all_close(expected[0], got[0], atol=1e-2 if use_mixed else 1e-4)
    print(
        f"-- debug backend worked, {len(got)} gradients, first one is "
        f"{got[0].shape}, {got[0].dtype}"
    )

else:
    # onnxrt backend. Like the backward branch, each backend compiles its own
    # copy so the same module instance is never compiled with two backends.
    local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
    optimized_mod = torch.compile(
        copy.deepcopy(model), backend=local_aot_ort, fullgraph=True
    )
    with dump_onnx("llama_onnxrt", folder=folder, clean=True):
        expected_onnxrt = _run_maybe_mixed(optimized_mod, *inputs[0])
    assert_all_close(expected, expected_onnxrt, atol=1e-2)

    # debugging backend, using the opset given by --opset (as in backward mode)
    aot_compiler = _make_debug_compiler()
    onnx_mod = torch.compile(copy.deepcopy(model), backend=aot_compiler, fullgraph=True)
    got = _run_maybe_mixed(onnx_mod, *inputs[0])
    assert_all_close(expected, got, atol=1 if use_mixed else 1e-3)
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/_exporter_legacy.py:108: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py:503: FutureWarning: 'onnxscript.values.TracedOnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  self.param_schema = self.onnxfunction.param_schemas()
Applied 9 of general pattern rewrite rules.
Applied 1 of general pattern rewrite rules.

In forward mode, there are two files per backend: one onnx model and the graph module printed in a txt file. In backward mode, there are two onnx models per backend. The total is therefore multiplied by the number of backends.

# Every file both backends wrote into the dump folder (models and graph dumps).
models = os.listdir(folder)
print("exported models: " + str(models))
exported models: ['llama_debug_0.onnx', 'llama_onnxrt_0.txt', 'llama_debug_0.txt', 'llama_onnxrt_0.onnx']

Inputs used by the debug backend

# First set of inputs recorded by the debug backend for its first model;
# these feeds are reused below to run both onnx models.
first_instance = storage["instance"][0]
feeds = first_instance["inputs"][0]
for name, value in feeds.items():
    print(f"-- {name} {value.dtype} {value.shape}")
-- input0 float32 (2, 1024, 512)
-- input1 float32 (512, 512)
-- input2 float32 (512, 512)
-- input3 float32 (512, 512)
-- input4 float32 (32,)
-- input5 int64 (1, 1024)
-- input6 float32 (2, 1, 1024, 1024)
-- input7 float32 (512, 512)

Let’s print the first lines of the graph module.

# Only the first ten lines of the FX graph are shown.
graph_module = storage["instance"][0]["graph_module"]
graph_text = str(graph_module.graph)
head = graph_text.split("\n")[:10]
print("\n".join(head))
graph():
    %primals_1 : [num_users=3] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=1] = placeholder[target=primals_5]
    %primals_6 : [num_users=1] = placeholder[target=primals_6]
    %primals_7 : [num_users=1] = placeholder[target=primals_7]
    %primals_8 : [num_users=1] = placeholder[target=primals_8]
    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%primals_2,), kwargs = {})

Comparison and execution

if backward:
    print(f"-- {len(storage['instance'])} onnx models were created")
    for i, inst in enumerate(storage["instance"]):
        print(f"  model {i}: {len(inst['inputs'])} runs")

    # deal with backward: four models are expected (two per backend); the
    # comparison uses the pair whose file name contains "_1".
    onnx_models = sorted(m for m in models if m.endswith(".onnx"))
    assert len(onnx_models) == 4, f"unexpected value {onnx_models}"
    onnx_models = sorted(m for m in onnx_models if "_1" in m)
    assert len(onnx_models) == 2, f"unexpected value {onnx_models}"
    # Alphabetical order puts llama_debug_* before llama_onnxrt_*.
    model_onnxrt = os.path.join(folder, onnx_models[1])
    model_debug = os.path.join(folder, onnx_models[0])
else:
    onnx_models = sorted(m for m in models if m.endswith(".onnx"))
    if not onnx_models:
        raise RuntimeError(
            f"No onnx model was dumped into {folder!r}, found {models}"
        )
    if len(onnx_models) == 2:
        # Alphabetical order puts llama_debug_* before llama_onnxrt_*.
        model_onnxrt = os.path.join(folder, onnx_models[1])
        model_debug = os.path.join(folder, onnx_models[0])
    else:
        model_debug = os.path.join(folder, onnx_models[0])
        # the following error may appear:
        # Node type 'Rank' from domain 'pkg.onnxscript.torch_lib.common' is unknown
        print(f"One model is missing, onnx_models={onnx_models}")
        model_onnxrt = model_debug

print(f"model_onnxrt={model_onnxrt}")
print(f"model_debug={model_debug}")
model_onnxrt=dump_models/llama_onnxrt_0.onnx
model_debug=dump_models/llama_debug_0.onnx

The inputs of both models

onnxrt: [('INPUT', 'primals_2', 1, (512, 512)), ('INPUT', 'primals_1', 1, (2, 1024, 512)), ('INPUT', 'primals_3', 1, (512, 512)), ('INPUT', 'primals_4', 1, (512, 512)), ('INPUT', 'primals_5', 1, (32,)), ('INPUT', 'primals_6', 7, (1, 1024)), ('INPUT', 'primals_7', 1, (2, 1, 1024, 1024)), ('INPUT', 'primals_8', 1, (512, 512))]
debug: [('INPUT', 'input0', 1, (2, 1024, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (32,)), ('INPUT', 'input5', 7, (1, 1024)), ('INPUT', 'input6', 1, (2, 1, 1024, 1024)), ('INPUT', 'input7', 1, (512, 512))]

Inputs are not the same: the names differ, and model_debug has more entries once its initializers are listed, because some values were moved into its initializer list.

debug_inputs = inputs_from_onnx_model(model_debug, init=True)
print("debug:", debug_inputs)
debug: [('INPUT', 'input0', 1, (2, 1024, 512)), ('INPUT', 'input1', 1, (512, 512)), ('INPUT', 'input2', 1, (512, 512)), ('INPUT', 'input3', 1, (512, 512)), ('INPUT', 'input4', 1, (32,)), ('INPUT', 'input5', 7, (1, 1024)), ('INPUT', 'input6', 1, (2, 1, 1024, 1024)), ('INPUT', 'input7', 1, (512, 512)), ('INIT', 'init7_s2_2048_512', 7, (2,)), ('INIT', 'init7_s3_2_1024_512', 7, (3,)), ('INIT', 'init7_s4_2_1024_8_64', 7, (4,)), ('INIT', 'init7_s1_1', 7, (1,)), ('INIT', 'init1_s_', 1, ()), ('INIT', 'init7_s3_16_1024_64', 7, (3,)), ('INIT', 'init7_s3_16_64_1024', 7, (3,)), ('INIT', 'init1_s_2', 1, ()), ('INIT', 'init7_s3_16_1024_1024', 7, (3,)), ('INIT', 'init7_s2_0_2', 7, (2,)), ('INIT', 'init7_s2_32_32', 7, (2,))]

Optimization and Verification

Let’s try the model with a python backend (reference implementation). First, onnxscript produces models with many functions. The reference evaluator expects every function to be defined before it is used, so the order of functions in the model matters; this runtime does not allow recursion. The functions need to be reordered because function Rank is usually placed at the end of the model.

'dump_models/llama_onnxrt_0.onnx'

Let’s load the model and optimize them.

# Load both dumped models. The onnxrt one is also passed through
# optimize_model_proto_oxs; if a library it needs is missing, the debug model
# stands in for it.
debug = onnx.load(model_debug)
try:
    onnxrt_proto = onnx.load(model_onnxrt)
    onnxrt = optimize_model_proto_oxs(onnxrt_proto)
except ImportError as exc:
    print("missing library", exc)
    onnxrt = debug

Let’s apply onnxruntime optimization

def _opt_path(path, tag=".opt"):
    """Return *path* with *tag* inserted before its file extension.

    Only the extension is rewritten; ``str.replace(".onnx", ...)`` would also
    rewrite any ``.onnx`` appearing elsewhere in the path.
    """
    root, ext = os.path.splitext(path)
    return f"{root}{tag}{ext}"


if ortopt:
    providers = (
        [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
        if use_cuda
        else ["CPUExecutionProvider"]
    )
    # Keep the onnxrt model as it was before onnxruntime optimization.
    with open(_opt_path(model_onnxrt, ".before.opt"), "wb") as f:
        f.write(onnxrt.SerializeToString())

    print(f"run onnxruntime optimization on {model_onnxrt}")
    optimized = _opt_path(model_onnxrt)
    ort_optimize(onnxrt, output=optimized, providers=providers)
    onnxrt = onnx.load(optimized)

    print(f"run onnxruntime optimization on {model_debug}")
    optimized = _opt_path(model_debug)
    ort_optimize(debug, output=optimized, disable_aot=True, providers=providers)
    debug = onnx.load(optimized)
run onnxruntime optimization on dump_models/llama_onnxrt_0.onnx
run onnxruntime optimization on dump_models/llama_debug_0.onnx

For what’s following, we need to build two lists of matching inputs.

# The two models name their inputs differently, so the debug feeds are
# remapped onto the onnxrt model's input names.
print("build_matching_inputs")
feedsrt = build_matching_inputs(
    model_debug,
    feeds,
    model_onnxrt,
)
print("done")
build_matching_inputs
done

We check both models are running.

# Run both models with the python reference evaluator; each must return outputs.
out_onnxrt = ExtendedReferenceEvaluator(onnxrt).run(None, feedsrt)
out_debug = ExtendedReferenceEvaluator(debug).run(None, feeds)
for outputs in (out_onnxrt, out_debug):
    assert outputs

# The outputs are not compared one by one here: the two models return them in
# a different order (see the OUTPUT rows of the side-by-side comparison).

Side by side

# Execute both models and align their intermediate results: the onnxrt model
# is shown in the left column, the debug model in the right one.
comparison = compare_onnx_execution(
    onnxrt, debug, verbose=1, raise_exc=True, inputs=(feedsrt, feeds)
)
res1, res2, align, dc = comparison
text = dc.to_str(res1, res2, align, column_size=90)
print(text)
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 94 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 94 results (first model)
[compare_onnx_execution] got 82 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 105 pairs
[compare_onnx_execution] done
001 = | INITIA int64    1:2                  USAA                 _val_301                         | INITIA int64    1:2                  USAA                 init7_s2_2048_512
002 - | INITIA int64                         AAAA                 aten_unsqueeze_75_dim_0          |
003 - | INITIA int64    1:4                  CIKK                 _val_274                         |
004 - | INITIA int64                         BAAA                 aten_unsqueeze_311_dim_0         |
005 - | INITIA int64    1:3                  QMKA                 _val_269                         |
006 = | INITIA int64    1:3                  CKSA                 _val_96                          | INITIA int64    1:3                  CKSA                 init7_s3_2_1024_512
007 - | INITIA int64                         CAAA                 aten_unsqueeze_159_dim_0         |
008 ~ | INITIA int64    1:3                  QKMA                 _val_264                         | INITIA int64    1:4                  CKIM                 init7_s4_2_1024_8_64
009 ~ | INITIA int64    1:4                  CKIM                 _val_137                         | INITIA int64    1:1                  BAAA                 init7_s1_1
010 ~ | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64    1:2                  ACAA                 init7_s2_0_2
011 - | INITIA float32                       IAAA                 _val_276                         |
012 ~ | INITIA int64    1:3                  CKZA                 _val_298                         | INITIA int64    1:3                  QKMA                 init7_s3_16_1024_64
013 ~ | INITIA int64    1:3                  QKKA                 _val_287                         | INITIA int64    1:3                  QMKA                 init7_s3_16_64_1024
014 = | INITIA int64    1:2                  GGAA                 splits_token_9                   | INITIA int64    1:2                  GGAA                 init7_s2_32_32
015 ~ | INITIA int64    1:4                  CIKM                 _val_293                         | INITIA int64    1:3                  QKKA                 init7_s3_16_1024_1024
016 + |                                                                                            | INPUT  float32  3:2x1024x512         RUEB                 input0
017 = | INPUT  float32  2:512x512            CKDZ                 primals_2                        | INPUT  float32  2:512x512            CKDZ                 input1
018 - | INPUT  float32  3:2x1024x512         RUEB                 primals_1                        |
019 = | INPUT  float32  2:512x512            RHAY                 primals_3                        | INPUT  float32  2:512x512            RHAY                 input2
020 = | INPUT  float32  2:512x512            IWCH                 primals_4                        | INPUT  float32  2:512x512            IWCH                 input3
021 = | INPUT  float32  1:32                 DAAA                 primals_5                        | INPUT  float32  1:32                 DAAA                 input4
022 = | INPUT  int64    2:1x1024             KAQG                 primals_6                        | INPUT  int64    2:1x1024             KAQG                 input5
023 = | INPUT  float32  4:2x1x1024x1024      AAAA                 primals_7                        | INPUT  float32  4:2x1x1024x1024      AAAA                 input6
024 = | INPUT  float32  2:512x512            YFSV                 primals_8                        | INPUT  float32  2:512x512            YFSV                 input7
025 - | RESULT float32  2:512x512            YFSV Identity        t_6                              |
026 = | RESULT int64    3:1x1x1024           KAQG Unsqueeze       unsqueeze_2                      | RESULT int64    3:1x1x1024           KAQG Unsqueeze       unsqueeze_2
027 = | RESULT float32  3:1x1x1024           KAQG Cast            _to_copy                         | RESULT float32  3:1x1x1024           KAQG Cast            _to_copy
028 - | RESULT float32  2:1x32               DAAA Unsqueeze       unsqueeze                        |
029 = | RESULT float32  3:1x32x1             DAAA Unsqueeze       unsqueeze_1                      | RESULT float32  3:1x32x1             DAAA Unsqueeze       unsqueeze_1
030 + |                                                                                            | RESULT float32  3:1x32x1024          EFXM MatMul          bmm
031 - | RESULT float32  3:1x1024x32          XCHM FusedMatMul     transpose_3                      |
032 - | RESULT float32  3:1x1024x64          VFPY Concat          cat                              |
033 ~ | RESULT float32  3:1x1024x64          GSEC Sin             sin                              | RESULT float32  3:1x64x1024          JKJK Concat          cat_token_5
034 ~ | RESULT float32  4:1x1x1024x64        GSEC Unsqueeze       unsqueeze_4                      | RESULT float32  3:1x64x1024          RMRM Sin             sin_token_7
035 + |                                                                                            | RESULT float32  4:1x1x64x1024        RMRM Unsqueeze       Opset8
036 = | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_4_out0           | RESULT float32  4:1x1024x1x64        GSEC Transpose       Transpose_token_10_out0
037 = | RESULT float32  2:2048x512           RUEB Reshape         view                             | RESULT float32  2:2048x512           RUEB Reshape         output_2
038 ~ | RESULT float32  2:2048x512           XSDJ FusedMatMul     mm_1                             | RESULT float32  2:2048x512           XSDJ Gemm            mm_1
039 - | RESULT float32  3:2x1024x512         XSDJ Reshape         _unsafe_view_1                   |
040 = | RESULT float32  4:2x1024x8x64        XSDJ Reshape         view_4                           | RESULT float32  4:2x1024x8x64        XSDJ Reshape         view_4
041 = | RESULT float32  4:2x1024x8x32        YHIS Split           Slice_263                        | RESULT float32  4:2x1024x8x32        YHIS Split           SlicesSplitPattern--slice_Tensor
042 = | RESULT float32  4:2x1024x8x32        ZLWS Split           Slice_280                        | RESULT float32  4:2x1024x8x32        ZLWS Split           SlicesSplitPattern--slice_Tensor
043 = | RESULT float32  4:2x1024x8x32        BPEI Neg             aten_neg_290_n0                  | RESULT float32  4:2x1024x8x32        BPEI Neg             neg2
044 = | RESULT float32  4:2x1024x8x64        AWMZ Concat          Concat_294                       | RESULT float32  4:2x1024x8x64        AWMZ Concat          cat3
045 = | RESULT float32  4:2x1024x8x64        IOVZ Mul             Mul_315                          | RESULT float32  4:2x1024x8x64        IOVZ Mul             mul_Tensor10
046 + |                                                                                            | RESULT float32  3:1x64x1024          NHNH Cos             cos_token_13
047 ~ | RESULT float32  3:1x1024x64          CJYF Cos             cos                              | RESULT float32  4:1x1x64x1024        NHNH Unsqueeze       Opset7
048 - | RESULT float32  4:1x1x1024x64        CJYF Unsqueeze       unsqueeze_3                      |
049 = | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_6_out0           | RESULT float32  4:1x1024x1x64        CJYF Transpose       Transpose_token_16_out0
050 = | RESULT float32  4:2x1024x8x64        NPOM Mul             Mul_313                          | RESULT float32  4:2x1024x8x64        NPOM Mul             mul_Tensor9
051 = | RESULT float32  4:2x1024x8x64        VDKK Add             Add_317                          | RESULT float32  4:2x1024x8x64        VDKK Add             add_Tensor2
052 = | RESULT float32  4:2x8x64x1024        GSVA Transpose       transpose_4                      | RESULT float32  4:2x8x64x1024        GSVA Transpose       transpose_4
053 + |                                                                                            | RESULT float32  4:1x1x1024x64        GSEC Transpose       output_5
054 ~ | RESULT float32  3:16x64x1024         GSVA Reshape         _unsafe_view_4                   | RESULT float32  2:2048x512           RUEB Reshape         output_1
055 ~ | RESULT float32  2:2048x512           AKZF FusedMatMul     mm                               | RESULT float32  2:2048x512           AKZF Gemm            mm
056 - | RESULT float32  3:2x1024x512         AKZF Reshape         _unsafe_view                     |
057 = | RESULT float32  4:2x1024x8x64        AKZF Reshape         view_3                           | RESULT float32  4:2x1024x8x64        AKZF Reshape         view_3
058 = | RESULT float32  4:2x8x1024x64        MXTM Transpose       transpose                        | RESULT float32  4:2x8x1024x64        MXTM Transpose       transpose
059 = | RESULT float32  4:2x8x1024x32        EYII Split           slice_4                          | RESULT float32  4:2x8x1024x32        EYII Split           slice_4
060 = | RESULT float32  4:2x8x1024x32        IZLF Split           slice_5                          | RESULT float32  4:2x8x1024x32        IZLF Split           slice_5
061 = | RESULT float32  4:2x8x1024x32        SBPV Neg             neg                              | RESULT float32  4:2x8x1024x32        SBPV Neg             neg
062 = | RESULT float32  4:2x8x1024x64        WYXD Concat          cat_1                            | RESULT float32  4:2x8x1024x64        WYXD Concat          cat_1
063 = | RESULT float32  4:2x8x1024x64        UPSA Mul             mul_3                            | RESULT float32  4:2x8x1024x64        UPSA Mul             mul_3
064 + |                                                                                            | RESULT float32  4:1x1x1024x64        CJYF Transpose       output_4
065 = | RESULT float32  4:2x8x1024x64        ZZUM Mul             mul_2                            | RESULT float32  4:2x8x1024x64        ZZUM Mul             mul_2
066 = | RESULT float32  4:2x8x1024x64        SPNL Add             add                              | RESULT float32  4:2x8x1024x64        SPNL Add             add
067 - | RESULT float32  3:16x1024x64         SPNL Reshape         _unsafe_view_3                   |
068 - | RESULT float32  3:16x1024x1024       NBIS MatMul          bmm_1                            |
069 - | RESULT float32  4:2x8x1024x1024      NBIS Reshape         view_9                           |
070 ~ | RESULT float32  4:2x8x1024x1024      FQOI Div             div                              | RESULT float32  4:2x8x1024x1024      FQOI FusedMatMul     div
071 = | RESULT float32  4:2x8x1024x1024      FQOI Add             add_2                            | RESULT float32  4:2x8x1024x1024      FQOI Add             add_2
072 = | RESULT float32  4:2x8x1024x1024      NNNN Softmax         _softmax                         | RESULT float32  4:2x8x1024x1024      NNNN Softmax         output_8
073 - | RESULT float32  3:16x1024x1024       NNNN Reshape         view_10                          |
074 ~ | RESULT float32  2:2048x512           ENDH FusedMatMul     mm_2                             | RESULT float32  2:2048x512           RUEB Reshape         output_3
075 ~ | RESULT float32  3:2x1024x512         ENDH Reshape         _unsafe_view_2                   | RESULT float32  2:2048x512           ENDH Gemm            mm_2
076 = | RESULT float32  4:2x1024x8x64        ENDH Reshape         view_5                           | RESULT float32  4:2x1024x8x64        ENDH Reshape         view_5
077 = | RESULT float32  4:2x8x1024x64        HLUQ Transpose       transpose_2                      | RESULT float32  4:2x8x1024x64        HLUQ Transpose       transpose_2
078 ~ | RESULT float32  3:16x1024x64         HLUQ Reshape         _unsafe_view_5                   | RESULT float32  4:2x8x1024x64        FQYP MatMul          view_11
079 ~ | RESULT float32  3:16x1024x64         FQYP MatMul          bmm_2                            | RESULT float32  4:2x1024x8x64        LKXS Transpose       transpose_5
080 ~ | RESULT float32  4:2x8x1024x64        FQYP Reshape         view_11                          | RESULT float32  2:2048x512           LKXS Reshape         output_12
081 ~ | RESULT float32  4:2x1024x8x64        LKXS Transpose       transpose_5                      | RESULT float32  2:2048x512           OOUR Gemm            mm_3
082 ~ | RESULT float32  3:2x1024x512         LKXS Reshape         view_12                          | RESULT float32  3:2x1024x512         OOUR Reshape         output_0
083 + |                                                                                            | RESULT float32  3:16x1024x1024       NNNN Reshape         output_9
084 ~ | RESULT float32  2:2048x512           LKXS Reshape         view_13                          | RESULT float32  3:16x64x1024         GSVA Reshape         output_7
085 ~ | RESULT float32  2:2048x512           OOUR FusedMatMul     mm_3                             | RESULT float32  3:16x1024x64         SPNL Reshape         output_6
086 ~ | RESULT float32  3:2x1024x512         OOUR Reshape         _unsafe_view_6                   | RESULT float32  3:16x1024x64         HLUQ Reshape         output_10
087 + |                                                                                            | RESULT float32  2:512x512            CFZI Transpose       output_11
088 - | RESULT float32  3:16x1024x1024       NNNN Transpose       transpose_7                      |
089 - | RESULT float32  4:2x8x1024x1024      NNNN Identity        detach_3                         |
090 ~ | RESULT float32  3:16x1024x64         GSVA Transpose       transpose_10                     | OUTPUT float32  3:2x1024x512         OOUR                 output_0
091 ~ | RESULT float32  3:16x64x1024         SPNL Transpose       transpose_9                      | OUTPUT float32  2:2048x512           RUEB                 output_1
092 - | RESULT float32  3:16x64x1024         HLUQ Transpose       transpose_8                      |
093 = | OUTPUT float32  2:2048x512           RUEB                 view                             | OUTPUT float32  2:2048x512           RUEB                 output_2
094 - | OUTPUT float32  2:512x512            YFSV                 t_6                              |
095 ~ | OUTPUT float32  3:16x64x1024         HLUQ                 transpose_8                      | OUTPUT float32  2:2048x512           RUEB                 output_3
096 ~ | OUTPUT float32  3:1x1024x64          VFPY                 cat                              | OUTPUT float32  4:1x1x1024x64        CJYF                 output_4
097 + |                                                                                            | OUTPUT float32  4:1x1x1024x64        GSEC                 output_5
098 + |                                                                                            | OUTPUT float32  3:16x1024x64         SPNL                 output_6
099 ~ | OUTPUT float32  3:16x64x1024         SPNL                 transpose_9                      | OUTPUT float32  3:16x64x1024         GSVA                 output_7
100 - | OUTPUT float32  3:16x1024x64         GSVA                 transpose_10                     |
101 = | OUTPUT float32  4:2x8x1024x1024      NNNN                 detach_3                         | OUTPUT float32  4:2x8x1024x1024      NNNN                 output_8
102 = | OUTPUT float32  3:16x1024x1024       NNNN                 transpose_7                      | OUTPUT float32  3:16x1024x1024       NNNN                 output_9
103 ~ | OUTPUT float32  2:2048x512           LKXS                 view_13                          | OUTPUT float32  3:16x1024x64         HLUQ                 output_10
104 + |                                                                                            | OUTPUT float32  2:512x512            CFZI                 output_11
105 ~ | OUTPUT float32  3:2x1024x512         OOUR                 _unsafe_view_6                   | OUTPUT float32  2:2048x512           LKXS                 output_12

Total running time of the script: (0 minutes 5.840 seconds)

Gallery generated by Sphinx-Gallery