Gemm Exploration with CUDA

One big Gemm or two smaller Gemms?
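The comparison rests on a simple identity: concatenating the results of two
Gemms that share the same left operand gives the same matrix as one Gemm
against the concatenated right operands. A minimal NumPy sketch of that
equivalence (illustration only, not part of the benchmarked code):

import numpy as np

X = np.random.randn(4, 8).astype(np.float32)
Y = np.random.randn(8, 3).astype(np.float32)
Z = np.random.randn(8, 5).astype(np.float32)

# two Gemms, results concatenated along the column axis
two = np.concatenate([X @ Y, X @ Z], axis=1)
# one Gemm on the concatenated weights
one = X @ np.concatenate([Y, Z], axis=1)

assert np.allclose(two, one)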

Cache Performance

from onnx_extended.args import get_parsed_args

script_args = get_parsed_args(
    "plot_op_gemm2_cuda",
    description=__doc__,
    config=(
        "small",
        "small, short optimization (default), "
        "medium for medium sizes, "
        "large for big sizes",
    ),
    warmup=3,
    repeat=5,
    itype=(1, "1 for float32, 10 for float16"),
    expose="config,itype,warmup,repeat",
)

itype = script_args.itype
config = script_args.config
print(f"config={config}")
print(f"itype={itype}")

if config == "small":
    sizes = (256, 512, 1024)
elif config == "medium":
    sizes = (512, 1024, 2048)
elif config == "large":
    sizes = (1024, 2048, 4096, 8192)
else:
    try:
        sizes = list(map(int, config.split(",")))
    except (ValueError, TypeError) as e:
        raise AssertionError(f"Unexpected config value {config!r}.") from e

import time
import numpy as np
import onnx.helper as oh
from tqdm import tqdm
from pandas import DataFrame
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs


def get_model1(itype):
    return oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Gemm", ["X", "Y"], ["XY"]),
                oh.make_node("Gemm", ["X", "Z"], ["XZ"]),
                oh.make_node("Concat", ["XY", "XZ"], ["XYZ"], axis=1),
            ],
            "nd",
            [
                oh.make_tensor_value_info("X", itype, [None, None]),
                oh.make_tensor_value_info("Y", itype, [None, None]),
                oh.make_tensor_value_info("Z", itype, [None, None]),
            ],
            [oh.make_tensor_value_info("XYZ", itype, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )


print(onnx_simple_text_plot(get_model1(itype)))
config=small
itype=1
opset: domain='' version=18
input: name='X' type=dtype('float32') shape=['', '']
input: name='Y' type=dtype('float32') shape=['', '']
input: name='Z' type=dtype('float32') shape=['', '']
Gemm(X, Y) -> XY
Gemm(X, Z) -> XZ
  Concat(XY, XZ, axis=1) -> XYZ
output: name='XYZ' type=dtype('float32') shape=['', '']

And the other model, which fuses the two Gemms into a single one:

def get_model2(itype):
    return oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Concat", ["Y", "Z"], ["YZ"], axis=1),
                oh.make_node("Gemm", ["X", "YZ"], ["XYZ"]),
            ],
            "nd",
            [
                oh.make_tensor_value_info("X", itype, [None, None]),
                oh.make_tensor_value_info("Y", itype, [None, None]),
                oh.make_tensor_value_info("Z", itype, [None, None]),
            ],
            [oh.make_tensor_value_info("XYZ", itype, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )


print(onnx_simple_text_plot(get_model2(itype)))
opset: domain='' version=18
input: name='X' type=dtype('float32') shape=['', '']
input: name='Y' type=dtype('float32') shape=['', '']
input: name='Z' type=dtype('float32') shape=['', '']
Concat(Y, Z, axis=1) -> YZ
  Gemm(X, YZ) -> XYZ
output: name='XYZ' type=dtype('float32') shape=['', '']

InferenceSession

has_cuda = "CUDAExecutionProvider" in get_available_providers()

if has_cuda:

    dtype = np.float32 if itype == 1 else np.float16

    x = np.random.randn(16, 16).astype(dtype)
    y = np.random.randn(16, 16).astype(dtype)
    z = np.random.randn(16, 16).astype(dtype)
    feeds = dict(X=x, Y=y, Z=z)

    sess1 = InferenceSession(
        get_model1(itype).SerializeToString(), providers=["CUDAExecutionProvider"]
    )
    expected = sess1.run(None, feeds)[0]

The other model.

if has_cuda:

    opts = SessionOptions()
    opts.register_custom_ops_library(get_ort_ext_libs()[0])

    sess2 = InferenceSession(
        get_model2(itype).SerializeToString(), opts, providers=["CUDAExecutionProvider"]
    )
    got = sess2.run(None, feeds)[0]

Discrepancies

if has_cuda:

    diff = np.abs(got - expected).max()
    print(f"diff={diff}")
diff=0.0
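
The exact match holds for float32. With itype=10 (float16), small rounding
differences between the two graphs are expected, so a tolerance-based check
is safer. A sketch, where the float16 tolerance is an assumed, illustrative
value:

if has_cuda:
    # atol=1e-2 for float16 is an assumption; tune it to the actual hardware
    atol = 1e-5 if itype == 1 else 1e-2
    np.testing.assert_allclose(expected, got, atol=atol)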

Benchmark

Some helper code first, to avoid measuring the host-to-device copies.

def move_inputs(sess, feeds):
    from onnxruntime.capi._pybind_state import (
        SessionIOBinding,
        OrtDevice as C_OrtDevice,
        OrtValue as C_OrtValue,
    )

    input_names = [i.name for i in sess.get_inputs()]

    # allocate on CUDA device 0
    ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)

    # copy the numpy inputs to the device once, before any timing starts
    feed_ort_value = [
        (name, C_OrtValue.ortvalue_from_numpy(feeds[name], ort_device))
        for name in input_names
    ]

    # bind the inputs by device pointer so run_with_iobinding skips the copy
    bind = SessionIOBinding(sess._sess)
    for name, value in feed_ort_value:
        bind.bind_input(
            name, ort_device, feeds[name].dtype, value.shape(), value.data_ptr()
        )
    # let onnxruntime allocate the outputs on the same device
    for o in sess.get_outputs():
        bind.bind_output(o.name, ort_device)
    return bind, feed_ort_value
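
Usage is straightforward; this is exactly what the benchmark function below
does:

if has_cuda:
    bind, cuda_feeds = move_inputs(sess1, feeds)
    # inputs are already on the device, the run copies nothing
    sess1._sess.run_with_iobinding(bind, None)

The function returns the OrtValues alongside the binding so the caller keeps
them alive: the binding only stores their device pointers.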

Benchmark function

def benchmark(sess, sizes, label):

    data = []
    for size in tqdm(sizes):

        x = np.random.randn(size, size).astype(dtype)
        y = np.random.randn(size, size).astype(dtype)
        z = np.random.randn(size, size).astype(dtype)
        feeds = dict(X=x, Y=y, Z=z)
        bind, cuda_feeds = move_inputs(sess, feeds)

        # warmup runs, measured separately
        begin = time.perf_counter()
        for i in range(script_args.warmup):
            # same as sess.run(None, feeds) but without host/device copies
            sess._sess.run_with_iobinding(bind, None)
        warmup = time.perf_counter() - begin

        # measured runs
        times = []
        for i in range(script_args.repeat):
            begin = time.perf_counter()
            sess._sess.run_with_iobinding(bind, None)
            times.append(time.perf_counter() - begin)

        npt = np.array(times)
        obs = dict(
            warmup=warmup,
            time=npt.mean(),
            std=npt.std(),
            min=npt.min(),
            max=npt.max(),
            repeat=script_args.repeat,
            size=size,
            label=label,
        )
        data.append(obs)
    return data

Not Fused.

if has_cuda:

    print(f"sizes={sizes}")

    data_mul = benchmark(sess1, sizes, "Not Fused")
sizes=(256, 512, 1024)

100%|██████████| 3/3 [00:00<00:00,  5.95it/s]

Fused.

if has_cuda:

    data_mulmul = benchmark(sess2, sizes, "Fused")
100%|██████████| 3/3 [00:00<00:00, 22.85it/s]

Data

if has_cuda:

    df = DataFrame(data_mul + data_mulmul)
    df.to_csv("plot_op_gemm2_cuda.csv", index=False)
    df.to_excel("plot_op_gemm2_cuda.xlsx", index=False)
    print(df.head())
     warmup      time       std       min       max  repeat  size      label
0  0.002893  0.000871  0.000032  0.000851  0.000934       5   256  Not Fused
1  0.017671  0.003964  0.000080  0.003858  0.004046       5   512  Not Fused
2  0.076785  0.015037  0.010565  0.002347  0.025548       5  1024  Not Fused
3  0.000523  0.000125  0.000020  0.000113  0.000165       5   256      Fused
4  0.002584  0.000399  0.000006  0.000392  0.000409       5   512      Fused

Pivot.

if has_cuda:

    pivot = df.pivot(index="size", columns="label", values="time")
    pivot["ratio"] = pivot["Fused"] / pivot["Not Fused"]
    print(pivot)

    ax = pivot[["Not Fused", "Fused"]].plot(
        logx=True,
        logy=True,
        title=f"Fused/Unfused Gemm on CUDA\nitype={itype}",
    )
    ax.get_figure().savefig("plot_op_gemm2_cuda.png")
(figure: Fused/Unfused Gemm on CUDA, itype=1, saved as plot_op_gemm2_cuda.png)
label     Fused  Not Fused     ratio
size
256    0.000125   0.000871  0.143943
512    0.000399   0.003964  0.100652
1024   0.002087   0.015037  0.138787
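
To read the same numbers as a speedup factor rather than a ratio, one can
invert the column (a small convenience, not part of the original script):

if has_cuda:
    pivot["speedup"] = pivot["Not Fused"] / pivot["Fused"]
    print(pivot["speedup"])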

On these sizes, the fused model runs in roughly 10 to 14% of the unfused time, i.e. the single big Gemm is about 7 to 10 times faster than the two smaller ones.

Total running time of the script: (0 minutes 43.029 seconds)
