301: Compares LLAMA exporters

The script compares the two exporters implemented in PyTorch on a part of a Llama model. The models are compared after all optimizations have been applied, including those performed by onnxruntime.

To run the script:

python _doc/examples/plot_llama_diff_export --help

Some helpers

from experimental_experiment.args import get_parsed_args

# Command-line options of this example. Each keyword below looks like a
# (default value, help text) pair and ``expose`` seems to list the options
# made available on the command line — NOTE(review): confirm against
# ``get_parsed_args``. The values are read below as ``script_args.<name>``.
script_args = get_parsed_args(
    "plot_llama_diff_export",
    description=__doc__,
    part=("model", "one value among model, ..."),
    exporter=("dynamo", "one value among dynamo, custom"),
    ortopt=(1, "run onnxruntime optimization"),
    opset=(18, "onnx opset"),
    expose="part,exporter,ortopt,opset",
)

import contextlib
import os
import io
import warnings
import logging

# onnxruntime is needed to run the exported models. When it cannot be imported,
# the script prints a message and exits with status 0 (presumably so that a
# documentation build without onnxruntime does not fail).
try:
    with warnings.catch_warnings():
        # Silence any warning raised while importing onnxruntime.
        warnings.simplefilter("ignore")
        import onnxruntime

        # First hint only: refined later with the number of CUDA devices torch sees.
        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import numpy as np
import onnx
from onnx_array_api.reference import compare_onnx_execution
import torch
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.helpers import string_type
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.convert.convert_helper import ort_optimize
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.torch_models.dump_helper import reorder_functions_in_proto

# CUDA is used only when onnxruntime offers it AND torch sees at least one device.
if has_cuda:
    has_cuda = torch.cuda.device_count() > 0
# Drop every log record of severity ERROR or lower to keep the output short.
logging.disable(logging.ERROR)
provider = "cpu" if not has_cuda else "cuda"

The exporting functions

# Echo the options so the rendered page records the configuration used, and
# normalise ``ortopt`` (bool) and ``opset`` (int), which may arrive as strings
# from the command line.
print(f"part={script_args.part}")
print(f"exporter={script_args.exporter}")
ortopt = script_args.ortopt in (1, "1")
print(f"ortopt={ortopt}")
opset = int(script_args.opset)
print(f"opset={opset}")


def opt_filename(filename: str) -> str:
    """Return the file name used for the onnxruntime-optimized copy of *filename*.

    ``.opt`` is inserted right before the last extension, e.g.
    ``llama.model.script.onnx`` becomes ``llama.model.script.opt.onnx``.
    """
    stem, extension = os.path.splitext(filename)
    return stem + ".opt" + extension


def export_script(filename, model, *args):
    """Export *model* to *filename* with the TorchScript-based ONNX exporter.

    :param filename: destination of the exported ONNX model
    :param model: torch module to export
    :param args: positional inputs used to trace the model

    Standard output and warnings produced by the export are discarded. When the
    module-level ``ortopt`` flag is set, the saved model is reloaded, optimized
    with onnxruntime and written to ``opt_filename(filename)``.
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # ``dynamo=False`` pins the TorchScript exporter explicitly: newer
            # torch releases default ``torch.onnx.export`` to the dynamo exporter,
            # which would make this "script" baseline a second dynamo export.
            torch.onnx.export(
                model,
                args,
                filename,
                input_names=["input"],
                opset_version=opset,
                dynamo=False,
            )
    if ortopt:
        onx = onnx.load(filename)
        ort_optimize(onx, opt_filename(filename), providers=provider)


def export_dynamo(filename, model, *args):
    """Export *model* to *filename* with the dynamo-based ONNX exporter.

    :param filename: destination of the exported ONNX model
    :param model: torch module to export
    :param args: positional inputs given to the exporter

    The exported program is optimized by its own ``optimize()`` step before
    being saved. When the module-level ``ortopt`` flag is set, the model is also
    optimized with onnxruntime and written to ``opt_filename(filename)``.
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Forward the requested opset like the two other exporters do, so the
            # command-line ``opset`` option applies to every compared model.
            onnx_program = torch.onnx.export(
                model, args, dynamo=True, opset_version=opset
            )
            onnx_program.optimize()
            # Distinct name: do not shadow the torch ``model`` parameter.
            proto = onnx_program.model_proto
    with open(filename, "wb") as f:
        f.write(proto.SerializeToString())
    if ortopt:
        ort_optimize(proto, opt_filename(filename), providers=provider)


def export_custom(filename, model, *args):
    """Export *model* to *filename* with ``experimental_experiment``'s ``to_onnx``.

    :param filename: destination of the exported ONNX model
    :param model: torch module to export
    :param args: positional inputs, exposed as ``input0``, ``input1``, ...

    Unused nodes are removed and constant folding is disabled. When the
    module-level ``ortopt`` flag is set, the model is also optimized with
    onnxruntime and written to ``opt_filename(filename)``.
    """
    names = [f"input{i}" for i, _ in enumerate(args)]
    options = OptimizationOptions(remove_unused=True, constant_folding=False)
    proto = to_onnx(
        model,
        tuple(args),
        input_names=names,
        options=options,
        target_opset=opset,
    )
    with open(filename, "wb") as stream:
        stream.write(proto.SerializeToString())
    if not ortopt:
        return
    ort_optimize(proto, opt_filename(filename), providers=provider)
part=model
exporter=dynamo
ortopt=True
opset=18

Model and data

# ``input_dims`` holds two (2, 1024) entries, one per generated input set.
kwargs = {"input_dims": [(2, 1024)] * 2}
if not unit_test_going():
    # Outside unit tests, build a one-layer model with an explicit configuration.
    kwargs.update(
        _attn_implementation="eager",
        num_hidden_layers=1,
        hidden_size=512,
        vocab_size=4000,
        intermediate_size=2000,
        max_position_embeddings=2048,
        num_attention_heads=8,
    )

if script_args.part != "model":
    raise RuntimeError(f"Unexpected value for part={script_args.part!r}")
model, inputs = get_llama_model(**kwargs)

# Run the eager model once; its output is the reference for every comparison.
print(f"simple run with {len(inputs)} inputs")
expected = model(*inputs[0])
print(f"eager worked: {string_type(expected, with_shape=True)}")
simple run with 2 inputs
eager worked: (T1s2x1024x512,)

Exporting

exporter = script_args.exporter
file1 = f"llama.{script_args.part}.script.onnx"
file2 = f"llama.{script_args.part}.{exporter}.onnx"

# The TorchScript export always runs, whatever ``--exporter`` says.
print("torch script exporter")
export_script(file1, model, *inputs[0])

# The second model comes from the exporter chosen on the command line.
second_exporters = {"dynamo": export_dynamo, "custom": export_custom}
if exporter not in second_exporters:
    raise AssertionError(f"Unexpected value for exporter={exporter!r}.")
print(f"torch {exporter} exporter")
second_exporters[exporter](file2, model, *inputs[0])
torch script exporter
torch dynamo exporter

Verification

if ortopt:
    print("Using models optimized by onnxruntime")
    # Switch to the copies written by ``ort_optimize`` in the export helpers,
    # named by the same helper that produced them.
    file1 = opt_filename(file1)
    file2 = opt_filename(file2)


# onnxruntime takes providers in priority order: CUDA first, CPU after it.
if provider == "cpu":
    providers = ["CPUExecutionProvider"]
else:
    providers = [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]

model1 = onnx.load(file1)
model2 = onnx.load(file2)

# Both models get the same numpy arrays. Inputs are matched by position because
# the two exporters give them different names.
feeds1, feeds2 = {}, {}
for i, tensor in enumerate(inputs[0]):
    x = tensor.detach().numpy()
    feeds1[model1.graph.input[i].name] = x
    feeds2[model2.graph.input[i].name] = x

if ortopt:
    # Run both onnxruntime-optimized models on the same inputs.
    sess1 = onnxruntime.InferenceSession(file1, providers=providers)
    sess2 = onnxruntime.InferenceSession(file2, providers=providers)

    got1 = sess1.run(None, feeds1)
    got2 = sess2.run(None, feeds2)

    # The eager model may return a 1-tuple; unwrap it to get the tensor itself.
    if isinstance(expected, tuple) and len(expected) == 1:
        (expected,) = expected
    reference = expected.detach().numpy()
    diff1, diff2 = (np.abs(reference - got[0]).max() for got in (got1, got2))

    print(f"Error with the eager model and onnxruntime: {diff1}, {diff2}")
Using models optimized by onnxruntime
Error with the eager model and onnxruntime: 0.00330527126789093, 0.00330527126789093

Verification with the reference evaluator

reorder_functions_in_proto(file1)
reorder_functions_in_proto(file2)

sess1 = ExtendedReferenceEvaluator(file1)
try:
    sess2 = ExtendedReferenceEvaluator(file2)
except NotImplementedError as e:
    # The reference evaluator cannot load the second model; report why and go on.
    print(e)
    sess2 = None

# ``expected`` may still be a tuple (it is only unwrapped when ``ortopt`` is set).
expected_ref = (expected[0] if isinstance(expected, tuple) else expected).detach().numpy()

got1 = sess1.run(None, feeds1)
diff1 = np.abs(expected_ref - got1[0]).max()
if sess2 is None:
    # Do not reuse the first model's results: that would print the first
    # exporter's error as if it were the second exporter's.
    got2 = None
    diff2 = "n/a"
else:
    got2 = sess2.run(None, feeds2)
    diff2 = np.abs(expected_ref - got2[0]).max()

print(f"Error with the eager model and the reference evaluator: {diff1}, {diff2}")
Error with the eager model and the reference evaluator: 3.0994415283203125e-06, 3.516674041748047e-06

Comparison and execution

def clean_name(name):
    """Remove the ``_inlfunc_...`` fragments from a result name to shorten it.

    The fragments (presumably left by function inlining) are removed one after
    the other, in the order listed below.
    """
    for fragment in (
        "_inlfunc_transformers_models_llama_modeling_llama_LlamaAttention",
        "_inlfunc_torch_nn_modules_linear_Linear",
    ):
        name = name.replace(fragment, "")
    return name


# Compare both models result by result. This is only attempted when the
# reference evaluator could load the second model (``sess2`` is not None).
if sess2 is not None:
    try:
        np_inputs = [i.detach().numpy() for i in inputs[0]]
        # Executes both models with ExtendedReferenceEvaluator and aligns their
        # intermediate results (the verbose log reports an edit distance).
        res1, res2, align, dc = compare_onnx_execution(
            model1,
            model2,
            inputs=np_inputs,
            verbose=1,
            raise_exc=False,
            cls=ExtendedReferenceEvaluator,
        )
        # Shorten the second model's result names before printing the table.
        for r in res2:
            r.name = clean_name(r.name)
        text = dc.to_str(res1, res2, align, column_size=90)
        print(text)
    except AssertionError as e:
        # Only this known failure (a result that is a list, not a numpy array) is
        # tolerated and printed; any other assertion is re-raised.
        if "Unexpected type <class 'list'> for value, it must be a numpy array." not in str(e):
            raise
        print(e)
[compare_onnx_execution] execute with 2 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 83 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 83 results (first model)
[compare_onnx_execution] got 99 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 112 pairs
[compare_onnx_execution] done
001 = | INITIA float32  2:4000x512           WFSF                 model.embed_tokens.weight        | INITIA float32  2:4000x512           WFSF                 model.embed_tokens.weight
002 + |                                                                                            | INITIA float32  4:2x1x1024x1024      ????                 expand_1
003 + |                                                                                            | INITIA float32  4:1x1024x1x64        CJYF                 unsqueeze_10
004 - | INITIA float32  1:512                YYYY                 model.layers.0.input_layernorm.w |
005 - | INITIA float32  2:512x512            HUOE                 onnx::MatMul_381                 |
006 - | INITIA float32  2:512x512            YABW                 onnx::MatMul_397                 |
007 - | INITIA float32  2:512x512            BXHZ                 onnx::MatMul_398                 |
008 = | INITIA float32  2:512x512            AZGY                 onnx::MatMul_423                 | INITIA float32  2:512x512            AZGY                 val_321
009 - | INITIA float32  2:512x2000           JWYN                 onnx::MatMul_424                 |
010 = | INITIA float32  2:512x2000           UVSU                 onnx::MatMul_425                 | INITIA float32  2:512x2000           UVSU                 val_327
011 + |                                                                                            | INITIA float32  2:512x512            HUOE                 val_239
012 - | INITIA float32  2:2000x512           YVOS                 onnx::MatMul_426                 |
013 - | INITIA int64    5:2x1x1024x1024x4    AQYO                 /model/Concat_output_0           |
014 ~ | INITIA int64    1:4                  CBKK                 /model/Concat_1_output_0         | INITIA float32                       AAAA                 val_237
015 - | INITIA float32  4:1x1024x1x64        GSEC                 /model/layers.0/self_attn/Unsque |
016 ~ | INITIA int64    1:2                  GGAA                 splits                           | INITIA int64    1:1                  ZAAA                 val_323
017 + |                                                                                            | INITIA float32  1:512                YYYY                 model.layers.0.input_layernorm.w
018 + |                                                                                            | INITIA float32  1:512                YYYY                 model.layers.0.post_attention_la
019 ~ | INITIA float32                       ?AAA                 /model/Constant_19_output_0      | INITIA float32  1:512                YYYY                 model.norm.weight
020 - | INITIA float32  4:2x1x1024x1024      ????                 /model/Slice_2_output_0          |
021 - | INITIA float32  4:1x1024x1x64        CJYF                 /model/layers.0/self_attn/Unsque |
022 ~ | INITIA int64    1:4                  CKZM                 /model/layers.0/self_attn/Consta | INITIA int64    1:2                  GGAA                 splits_token_10
023 + |                                                                                            | INITIA float32  2:2000x512           YVOS                 val_328
024 ~ | INITIA int64    1:2                  GGAA                 splits_token_12                  | INITIA float32                       ?AAA                 val_5
025 + |                                                                                            | INITIA float32  2:512x2000           JWYN                 val_325
026 ~ | INITIA int64    1:1                  AAAA                 /model/layers.0/self_attn/Consta | INITIA int64                         BAAA                 dim_0_10
027 + |                                                                                            | INITIA float32  2:512x512            YABW                 val_242
028 + |                                                                                            | INITIA float32  4:1024x1x2x1024      ????                 val_180
029 ~ | INITIA int64    1:1                  BAAA                 /model/layers.0/self_attn/Consta | INITIA int64    1:3                  CKZA                 val_320
030 + |                                                                                            | INITIA float32  2:512x512            BXHZ                 val_244
031 ~ | INITIA int64    1:1                  KAAA                 /model/layers.0/self_attn/Consta | INITIA int64                         CAAA                 val_20
032 + |                                                                                            | INITIA int64    2:1024x1             KAQG                 val_178
033 ~ | INITIA int64    1:1                  DAAA                 const_transpose_optimizer_token_ | INITIA int64    1:2                  GGAA                 splits
034 ~ | INITIA int64    1:1                  CAAA                 /model/Constant_13_output_0      | INITIA int64    1:4                  CKZM                 val_243
035 = | INITIA float32                       AAAA                 /model/Constant_14_output_0      | INITIA float32                       AAAA                 scalar_tensor_default
036 - | INITIA float32  4:2x1x1024x1024      ????                 /model/Expand_output_0           |
037 = | INITIA float32  4:1x1x1024x64        GSEC                 Transpose_token_4_out0           | INITIA float32  4:1x1x1024x64        GSEC                 Transpose_token_4_out0
038 + |                                                                                            | INITIA float32  4:1x1024x1x64        GSEC                 unsqueeze_11
039 ~ | INITIA int64    1:3                  CKZA                 /model/layers.0/self_attn/Consta | INITIA int64                         CAAA                 val_20_token_11
040 = | INPUT  int64    2:2x1024             NUZC                 input                            | INPUT  int64    2:2x1024             NUZC                 input_ids
041 = | INPUT  float32  2:2x1024             BACA                 attention_mask.1                 | INPUT  float32  2:2x1024             BACA                 attention_mask
042 = | RESULT float32  3:2x1024x512         LSTN Gather          /model/embed_tokens/Gather_outpu | RESULT float32  3:2x1024x512         LSTN Gather          embedding
043 ~ | RESULT float32  3:2x1024x512         FGKB SimplifiedLayer /model/layers.0/input_layernorm/ | RESULT float32  3:2x1024x512         ABAA Pow             pow_1
044 + |                                                                                            | RESULT float32  3:2x1024x1           AAAA ReduceMean      mean
045 + |                                                                                            | RESULT float32  3:2x1024x1           AAAA Add             add_1
046 + |                                                                                            | RESULT float32  3:2x1024x1           KKKK Sqrt            val_238
047 ~ | RESULT float32  3:2x1024x1           LVSZ SimplifiedLayer saved_inv_std_var                | RESULT float32  3:2x1024x1           LVSZ Reciprocal      rsqrt
048 + |                                                                                            | RESULT float32  3:2x1024x512         FGKB Mul             mul_3
049 + |                                                                                            | RESULT float32  3:2x1024x512         FGKB Mul             mul_4
050 = | RESULT float32  3:2x1024x512         PEXC MatMul          /model/layers.0/self_attn/k_proj | RESULT float32  3:2x1024x512         PEXC MatMul          linear_1
051 = | RESULT float32  4:2x1024x8x64        PEXC Reshape         /model/layers.0/self_attn/Reshap | RESULT float32  4:2x1024x8x64        PEXC Reshape         view_2
052 = | RESULT float32  4:2x1024x8x32        HTLG Split           /model/layers.0/self_attn/Slice_ | RESULT float32  4:2x1024x8x32        HTLG Split           node_Slice_363
053 = | RESULT float32  4:2x1024x8x32        ILLW Split           /model/layers.0/self_attn/Slice_ | RESULT float32  4:2x1024x8x32        ILLW Split           node_Slice_374
054 = | RESULT float32  4:2x1024x8x32        SPPE Neg             /model/layers.0/self_attn/Neg_1  | RESULT float32  4:2x1024x8x32        SPPE Neg             node_Neg_375
055 = | RESULT float32  4:2x1024x8x64        ZIAK Concat          /model/layers.0/self_attn/Concat | RESULT float32  4:2x1024x8x64        ZIAK Concat          node_Concat_376
056 = | RESULT float32  4:2x1024x8x64        TEUR Mul             /model/layers.0/self_attn/Mul_3  | RESULT float32  4:2x1024x8x64        TEUR Mul             node_Mul_377
057 = | RESULT float32  4:2x1024x8x64        MQHR Mul             /model/layers.0/self_attn/Mul_2  | RESULT float32  4:2x1024x8x64        MQHR Mul             node_Mul_352
058 = | RESULT float32  4:2x1024x8x64        GTBH Add             /model/layers.0/self_attn/Add_1  | RESULT float32  4:2x1024x8x64        GTBH Add             node_Add_378
059 = | RESULT float32  4:2x8x64x1024        NMTR Transpose       /model/layers.0/self_attn/Transp | RESULT float32  4:2x8x64x1024        NMTR Transpose       transpose_4
060 = | RESULT float32  3:2x1024x512         GFHI MatMul          /model/layers.0/self_attn/q_proj | RESULT float32  3:2x1024x512         GFHI MatMul          linear
061 = | RESULT float32  4:2x1024x8x64        GFHI Reshape         /model/layers.0/self_attn/Reshap | RESULT float32  4:2x1024x8x64        GFHI Reshape         view_1
062 = | RESULT float32  4:2x1024x8x64        RUMP Mul             /model/layers.0/self_attn/Mul    | RESULT float32  4:2x1024x8x64        RUMP Mul             node_Mul_324
063 = | RESULT float32  4:2x8x1024x64        XMRL Transpose       /model/layers.0/self_attn/Mul_ou | RESULT float32  4:2x8x1024x64        XMRL Transpose       mul_5
064 = | RESULT float32  4:2x8x1024x64        STND Transpose       /model/layers.0/self_attn/Transp | RESULT float32  4:2x8x1024x64        STND Transpose       transpose_1
065 = | RESULT float32  4:2x8x1024x32        TFVN Split           /model/layers.0/self_attn/Slice_ | RESULT float32  4:2x8x1024x32        TFVN Split           slice_24
066 = | RESULT float32  4:2x8x1024x32        ZOSQ Split           /model/layers.0/self_attn/Slice_ | RESULT float32  4:2x8x1024x32        ZOSQ Split           slice_25
067 = | RESULT float32  4:2x8x1024x32        BMIK Neg             /model/layers.0/self_attn/Neg_ou | RESULT float32  4:2x8x1024x32        BMIK Neg             neg
068 = | RESULT float32  4:2x8x1024x64        VSCX Concat          /model/layers.0/self_attn/Concat | RESULT float32  4:2x8x1024x64        VSCX Concat          cat_1
069 = | RESULT float32  4:2x8x1024x64        CXIF Mul             /model/layers.0/self_attn/Mul_1_ | RESULT float32  4:2x8x1024x64        CXIF Mul             mul_6
070 = | RESULT float32  4:2x8x1024x64        ZIYR Add             /model/layers.0/self_attn/Add_ou | RESULT float32  4:2x8x1024x64        ZIYR Add             add_2
071 = | RESULT float32  4:2x8x1024x1024      NEIM FusedMatMul     /model/layers.0/self_attn/Mul_4_ | RESULT float32  4:2x8x1024x1024      NEIM FusedMatMul     mul_9
072 = | RESULT float32  3:2x1x1024           BACA Unsqueeze       /model/Unsqueeze_2_output_0      | RESULT float32  3:2x1x1024           BACA Unsqueeze       unsqueeze_5
073 = | RESULT float32  4:2x1x1x1024         BACA Unsqueeze       /model/Unsqueeze_3_output_0      | RESULT float32  4:2x1x1x1024         BACA Unsqueeze       unsqueeze_6
074 = | RESULT float32  4:2x1x1024x1024      ???? Add             /model/Add_output_0              | RESULT float32  4:2x1x1024x1024      ???? Add             add
075 = | RESULT bool     4:2x1x1024x1024      KWTE Equal           /model/Equal_1_output_0          | RESULT bool     4:2x1x1024x1024      KWTE Equal           eq
076 = | RESULT float32  4:2x1x1024x1024      ???? Where           /model/Where_1_output_0          | RESULT float32  4:2x1x1024x1024      ???? Where           masked_fill
077 + |                                                                                            | RESULT float32  4:1024x1x2x1024      ???? Transpose       val_179
078 + |                                                                                            | RESULT float32  4:1024x1x2x1024      ???? ScatterND       val_181
079 - | RESULT float32  4:2x1x1024x1024      ???? Reshape         /model/Reshape_output_0          |
080 - | RESULT float32  4:2x1x1024x1024      ???? ScatterND       /model/ScatterND_output_0        |
081 ~ | RESULT float32  4:2x1x1024x1024      ???? Slice           /model/layers.0/self_attn/Slice_ | RESULT float32  4:2x1x1024x1024      ???? Transpose       slice_scatter_1
082 = | RESULT float32  4:2x8x1024x1024      ???? Add             /model/layers.0/self_attn/Add_2_ | RESULT float32  4:2x8x1024x1024      ???? Add             add_4
083 = | RESULT float32  4:2x8x1024x1024      OOON Softmax         /model/layers.0/self_attn/Softma | RESULT float32  4:2x8x1024x1024      OOON Softmax         val_318
084 = | RESULT float32  3:2x1024x512         LJKD MatMul          /model/layers.0/self_attn/v_proj | RESULT float32  3:2x1024x512         LJKD MatMul          linear_2
085 = | RESULT float32  4:2x1024x8x64        LJKD Reshape         /model/layers.0/self_attn/Reshap | RESULT float32  4:2x1024x8x64        LJKD Reshape         view_3
086 = | RESULT float32  4:2x8x1024x64        NIJF Transpose       /model/layers.0/self_attn/Transp | RESULT float32  4:2x8x1024x64        NIJF Transpose       transpose_3
087 = | RESULT float32  4:2x8x1024x64        PQLS MatMul          /model/layers.0/self_attn/MatMul | RESULT float32  4:2x8x1024x64        PQLS MatMul          matmul_2
088 = | RESULT float32  4:2x1024x8x64        PPVG Transpose       /model/layers.0/self_attn/Transp | RESULT float32  4:2x1024x8x64        PPVG Transpose       transpose_5
089 = | RESULT float32  3:2x1024x512         PPVG Reshape         /model/layers.0/self_attn/Reshap | RESULT float32  3:2x1024x512         PPVG Reshape         view_4
090 = | RESULT float32  3:2x1024x512         FFCQ MatMul          /model/layers.0/self_attn/o_proj | RESULT float32  3:2x1024x512         FFCQ MatMul          linear_3
091 = | RESULT float32  3:2x1024x512         QXUD Add             /model/layers.0/Add_output_0     | RESULT float32  3:2x1024x512         QXUD Add             add_5
092 ~ | RESULT float32  3:2x1024x512         NTCA SimplifiedLayer /model/layers.0/post_attention_l | RESULT float32  3:2x1024x512         TQKA Pow             pow_2
093 + |                                                                                            | RESULT float32  3:2x1024x1           YYKK ReduceMean      mean_1
094 + |                                                                                            | RESULT float32  3:2x1024x1           YYKK Add             add_6
095 + |                                                                                            | RESULT float32  3:2x1024x1           HHVU Sqrt            val_324
096 ~ | RESULT float32  3:2x1024x1           OOBM SimplifiedLayer saved_inv_std_var_token_10       | RESULT float32  3:2x1024x1           OOBM Reciprocal      rsqrt_1
097 + |                                                                                            | RESULT float32  3:2x1024x512         NTCA Mul             mul_10
098 + |                                                                                            | RESULT float32  3:2x1024x512         NTCA Mul             mul_11
099 = | RESULT float32  3:2x1024x2000        EHYE MatMul          /model/layers.0/mlp/gate_proj/Ma | RESULT float32  3:2x1024x2000        EHYE MatMul          linear_4
100 = | RESULT float32  3:2x1024x2000        JNTL QuickGelu       /model/layers.0/mlp/act_fn/Mul_o | RESULT float32  3:2x1024x2000        JNTL QuickGelu       silu
101 = | RESULT float32  3:2x1024x2000        YVQD MatMul          /model/layers.0/mlp/up_proj/MatM | RESULT float32  3:2x1024x2000        YVQD MatMul          linear_5
102 = | RESULT float32  3:2x1024x2000        BJDR Mul             /model/layers.0/mlp/Mul_output_0 | RESULT float32  3:2x1024x2000        BJDR Mul             mul_12
103 = | RESULT float32  3:2x1024x512         OJMD MatMul          /model/layers.0/mlp/down_proj/Ma | RESULT float32  3:2x1024x512         OJMD MatMul          linear_6
104 = | RESULT float32  3:2x1024x512         DGHG Add             /model/layers.0/Add_1_output_0   | RESULT float32  3:2x1024x512         DGHG Add             add_7
105 ~ | RESULT float32  3:2x1024x512         BOMY SimplifiedLayer 347                              | RESULT float32  3:2x1024x512         NJNL Pow             pow_3
106 + |                                                                                            | RESULT float32  3:2x1024x1           BBPP ReduceMean      mean_2
107 + |                                                                                            | RESULT float32  3:2x1024x1           BBPP Add             add_8
108 + |                                                                                            | RESULT float32  3:2x1024x1           OOKK Sqrt            val_331
109 ~ | RESULT float32  3:2x1024x1           BBDI SimplifiedLayer saved_inv_std_var_token_11       | RESULT float32  3:2x1024x1           BBDI Reciprocal      rsqrt_2
110 + |                                                                                            | RESULT float32  3:2x1024x512         BOMY Mul             mul_13
111 + |                                                                                            | RESULT float32  3:2x1024x512         BOMY Mul             mul_14
112 = | OUTPUT float32  3:2x1024x512         BOMY                 347                              | OUTPUT float32  3:2x1024x512         BOMY                 mul_14

See plot_llama_diff_export for a better view.

Total running time of the script: (0 minutes 54.755 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

301: Compares LLAMA exporters for onnxrt backend

301: Compares LLAMA exporters for onnxrt backend

102: Fuse kernels in a small Llama Model

102: Fuse kernels in a small Llama Model

201: Evaluate DORT Training

201: Evaluate DORT Training

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

Gallery generated by Sphinx-Gallery