Measuring Gemm performance with different input and output tests

This benchmark explores the type combinations accepted by the function cublasLtMatmul: each test fixes the types of the inputs A and B, the type of the output D, and the compute type. The tested configurations are defined in cuda_gemm.cu.

import pprint
import warnings
from itertools import product
from tqdm import tqdm
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnx_extended.args import get_parsed_args
from onnx_extended.ext_test_case import unit_test_going

try:
    from onnx_extended.validation.cuda.cuda_example_py import (
        gemm_benchmark_test,
        get_device_prop,
    )

    has_cuda = True
except ImportError:
    # CUDA not available.
    has_cuda = False
    gemm_benchmark_test = None
    get_device_prop = None

if has_cuda:
    prop = get_device_prop()
    if prop["major"] <= 0:
        # No CUDA.
        dtests, ddims = "", ""
    elif prop["major"] < 7:
        # No float 8.
        dtests, ddims = "0,1,2,3,4,15", "16,32,64,64x128x92"
    elif prop["major"] < 9:  # T100, A100
        # No float 8.
        dtests, ddims = (
            "0,1,2,3,4,15",
            "16,32,64,128,256,512,1024,2048,4096,8192,"
            "128x768x768,128x3072x768,128x768x3072",
        )
    else:
        dtests, ddims = (
            "0,1,2,3,4,5,6,7,11,14,15",
            "16,32,64,128,256,512,1024,2048,4096,8192,16384,"
            "128x768x768,128x3072x768,128x768x3072",
        )
else:
    dtests, ddims = "", ""


script_args = get_parsed_args(
    "plot_bench_gemm_f8",
    description=__doc__,
    dims=(
        "16,32" if unit_test_going() else ddims,
        "square matrix dimensions to try, comma separated values",
    ),
    tests=(
        "0,1,2" if unit_test_going() else dtests,
        "configuration to check, see cuda_gemm.cu",
    ),
    warmup=2 if unit_test_going() else 5,
    repeat=2 if unit_test_going() else 10,
    expose="repeat,warmup",
)
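
All parameters can be overridden on the command line. A hypothetical invocation, assuming the flag names follow the keys passed to get_parsed_args:

python plot_bench_gemm_f8.py --dims=16,32,64 --tests=0,1,2 --warmup=5 --repeat=10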

Device

if has_cuda:
    prop = get_device_prop()
    pprint.pprint(prop)
else:
    print("CUDA is not available")
    prop = dict(major=0)
{'clockRate': 1569000,
 'computeMode': 0,
 'concurrentKernels': 1,
 'isMultiGpuBoard': 0,
 'major': 6,
 'maxThreadsPerBlock': 1024,
 'minor': 1,
 'multiProcessorCount': 10,
 'name': 'NVIDIA GeForce GTX 1060',
 'sharedMemPerBlock': 49152,
 'totalConstMem': 65536,
 'totalGlobalMem': 6442319872}
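
The compute capability printed above drives the default configuration of the benchmark. A minimal sketch of the same gate used earlier, assuming float 8 requires compute capability 9.0 (Hopper) as the selection code above does:

capability = (prop["major"], prop["minor"])
if capability >= (9, 0):
    print("float 8 configurations (E4M3, E5M2) are benchmarked")
elif capability >= (7, 0):
    print("same type combinations on larger matrices, no float 8")
else:
    print("small matrices only, no float 8")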

Benchmark

def type2string(dt):
    # The integer codes follow the cudaDataType_t enumeration
    # consumed by cublasLtMatmul.
    dtypes = {
        0: "F32",
        2: "F16",
        14: "BF16",
        28: "E4M3",
        29: "E5M2",
        3: "I8",
        10: "I32",
    }
    return dtypes[int(dt)]


dims = []
tests = []
if gemm_benchmark_test is not None:
    for d in script_args.dims.split(","):
        if "x" in d:
            spl = d.split("x")
            m, n, k = tuple(int(i) for i in spl)
            dims.append((m, n, k))
        else:
            dims.append(int(d))
    tests = list(int(i) for i in script_args.tests.split(","))
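
# For instance, a hypothetical value dims="16,128x768x768" parses into
# [16, (128, 768, 768)]: one square and one rectangular configuration.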

pbar = tqdm(list(product(tests, dims)))
obs = []
for test, dim in pbar:
    pbar.set_description(f"type={test} dim={dim}")
    if test in {8, 9, 10, 12, 13}:
        warnings.warn(f"unsupported configuration {test}.")
        continue
    mdim = dim if isinstance(dim, int) else max(dim)
    # smaller problems are repeated more often to stabilize the measures
    if mdim < 128:
        n_warmup, n_repeat = script_args.warmup * 8, script_args.repeat * 8
    elif mdim < 512:
        n_warmup, n_repeat = script_args.warmup * 4, script_args.repeat * 4
    elif mdim < 8192:
        n_warmup, n_repeat = script_args.warmup * 2, script_args.repeat * 2
    else:
        n_warmup, n_repeat = script_args.warmup, script_args.repeat

    if isinstance(dim, int):
        # square case: m = n = k = lda = ldb = ldd = dim
        gemm_args = [dim] * 6
    else:
        m, n, k = dim
        lda, ldb, ldd = k, k, k
        gemm_args = [m, n, k, lda, ldb, ldd]

    # warmup
    gemm_benchmark_test(test, n_warmup, *gemm_args)

    # benchmark
    res = gemm_benchmark_test(test, n_repeat, *gemm_args)

    # better rendering: decode the dtype codes and average the times
    res["test"] = test
    update = {}
    for k, v in res.items():
        if "type_" in k:
            update[k] = type2string(v)
        if k.startswith("t-"):
            # times are cumulated over the N calls, report the average
            update[k] = res[k] / res["N"]
    update["compute_type"] = f"C{int(res['compute_type'])}"
    for c in ["N", "m", "n", "k", "lda", "ldb", "ldd"]:
        update[c] = int(res[c])
    # a scalar proxy for the problem size, used as x axis in the plots
    update["~dim"] = (update["k"] * max(update["m"], update["n"])) ** 0.5
    update["mnk"] = f"{update['m']}x{update['n']}x{update['k']}"
    update["name"] = (
        f"{update['type_a']}x{update['type_b']}->"
        f"{update['type_d']}{update['compute_type']}"
    )
    res.update(update)
    obs.append(res)
    if unit_test_going() and len(obs) > 2:
        break

df = DataFrame(obs)
df.to_csv("plot_bench_gemm_f8.csv", index=False)
# the xlsx export requires openpyxl to be installed
df.to_excel("plot_bench_gemm_f8.xlsx", index=False)
print(df.head().T)

df.head().T
                                0                1                2                3                4
t-total                  0.000143         0.000132         0.000149         0.000183         0.000093
t-clean                  0.000001         0.000001         0.000001         0.000001         0.000001
t-gemm_in                0.000027         0.000011         0.000009         0.000009         0.000008
t-setup                  0.000008         0.000012         0.000009         0.000009         0.000005
t-stream_create               0.0              0.0              0.0              0.0              0.0
N                              80               80               80               40               80
epiloque                      1.0              1.0              1.0              1.0              1.0
ldd                            16               32               64               92               16
t-workspace_free         0.000005         0.000004         0.000003         0.000003         0.000003
algo                         11.0              0.0              0.0              0.0             11.0
t-gemm_sync              0.000095          0.00012          0.00014         0.000174         0.000084
t-stream_destroy         0.000036         0.000002         0.000001         0.000001         0.000001
workspace_size          1048576.0        1048576.0        1048576.0        1048576.0        1048576.0
m                              16               32               64               64               16
k                              16               32               64               92               16
n                              16               32               64              128               16
compute_type                  C68              C68              C68              C68              C77
lda                            16               32               64               92               16
type_a                        F32              F32              F32              F32              F32
ldb                            16               32               64               92               16
t-gemm                   0.000037         0.000025          0.00002          0.00002         0.000015
type_b                        F32              F32              F32              F32              F32
t-workspace_new          0.000003         0.000004         0.000003         0.000002         0.000002
type_d                        F32              F32              F32              F32              F32
test                            0                0                0                0                1
~dim                         16.0             32.0             64.0        108.51728             16.0
mnk                      16x16x16         32x32x32         64x64x64        64x128x92         16x16x16
name              F32xF32->F32C68  F32xF32->F32C68  F32xF32->F32C68  F32xF32->F32C68  F32xF32->F32C77


Test definition

col_def = ["name", "test", "type_a", "type_b", "type_d", "compute_type"]
if df.shape[0] > 0:
    deft = df.copy()
    gr = deft[col_def].groupby(col_def, as_index=False).count()
    print(gr)
                 name  test type_a type_b type_d compute_type
0  BF16xBF16->BF16C68     4   BF16   BF16   BF16          C68
1     F16xF16->F16C64     3    F16    F16    F16          C64
2     F32xF32->F32C68     0    F32    F32    F32          C68
3     F32xF32->F32C75     2    F32    F32    F32          C75
4     F32xF32->F32C77     1    F32    F32    F32          C77
5       I8xI8->I32C72    15     I8     I8    I32          C72
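
The Cxx suffix in the names is the raw integer value of the compute type. A hedged decoding, assuming those values match the cublasComputeType_t enumeration from the cuBLAS headers:

compute_types = {
    64: "CUBLAS_COMPUTE_16F",
    68: "CUBLAS_COMPUTE_32F",
    72: "CUBLAS_COMPUTE_32I",
    75: "CUBLAS_COMPUTE_32F_FAST_16BF",
    77: "CUBLAS_COMPUTE_32F_FAST_TF32",
}
for code, name in sorted(compute_types.items()):
    print(f"C{code}: {name}")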

Total time and GEMM-only time

if df.shape[0] > 0:
    dfi = df[col_def + ["~dim", "mnk", "t-total", "t-gemm_sync"]]
    print(dfi)
                  name  test type_a type_b type_d compute_type       ~dim        mnk   t-total  t-gemm_sync
0      F32xF32->F32C68     0    F32    F32    F32          C68   16.00000   16x16x16  0.000143     0.000095
1      F32xF32->F32C68     0    F32    F32    F32          C68   32.00000   32x32x32  0.000132     0.000120
2      F32xF32->F32C68     0    F32    F32    F32          C68   64.00000   64x64x64  0.000149     0.000140
3      F32xF32->F32C68     0    F32    F32    F32          C68  108.51728  64x128x92  0.000183     0.000174
4      F32xF32->F32C77     1    F32    F32    F32          C77   16.00000   16x16x16  0.000093     0.000084
5      F32xF32->F32C77     1    F32    F32    F32          C77   32.00000   32x32x32  0.000154     0.000139
6      F32xF32->F32C77     1    F32    F32    F32          C77   64.00000   64x64x64  0.000159     0.000149
7      F32xF32->F32C77     1    F32    F32    F32          C77  108.51728  64x128x92  0.000194     0.000183
8      F32xF32->F32C75     2    F32    F32    F32          C75   16.00000   16x16x16  0.000090     0.000081
9      F32xF32->F32C75     2    F32    F32    F32          C75   32.00000   32x32x32  0.000170     0.000152
10     F32xF32->F32C75     2    F32    F32    F32          C75   64.00000   64x64x64  0.000158     0.000148
11     F32xF32->F32C75     2    F32    F32    F32          C75  108.51728  64x128x92  0.000195     0.000184
12     F16xF16->F16C64     3    F16    F16    F16          C64   16.00000   16x16x16  0.001053     0.001020
13     F16xF16->F16C64     3    F16    F16    F16          C64   32.00000   32x32x32  0.003493     0.003417
14     F16xF16->F16C64     3    F16    F16    F16          C64   64.00000   64x64x64  0.006469     0.006414
15     F16xF16->F16C64     3    F16    F16    F16          C64  108.51728  64x128x92  0.009382     0.009356
16  BF16xBF16->BF16C68     4   BF16   BF16   BF16          C68   16.00000   16x16x16  0.000115     0.000106
17  BF16xBF16->BF16C68     4   BF16   BF16   BF16          C68   32.00000   32x32x32  0.000237     0.000228
18  BF16xBF16->BF16C68     4   BF16   BF16   BF16          C68   64.00000   64x64x64  0.000296     0.000286
19  BF16xBF16->BF16C68     4   BF16   BF16   BF16          C68  108.51728  64x128x92  0.000401     0.000385
20       I8xI8->I32C72    15     I8     I8    I32          C72   16.00000   16x16x16  0.000112     0.000103
21       I8xI8->I32C72    15     I8     I8    I32          C72   32.00000   32x32x32  0.000191     0.000180
22       I8xI8->I32C72    15     I8     I8    I32          C72   64.00000   64x64x64  0.000187     0.000178
23       I8xI8->I32C72    15     I8     I8    I32          C72  108.51728  64x128x92  0.000249     0.000203
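
Raw timings are easier to compare across shapes as throughput. A minimal sketch, assuming the usual 2*m*n*k floating-point operation count for one GEMM:

if df.shape[0] > 0:
    dth = df[["name", "mnk"]].copy()
    # t-gemm_sync already holds the average time of a single call
    dth["TFLOPS"] = 2 * df["m"] * df["n"] * df["k"] / df["t-gemm_sync"] / 1e12
    print(dth.sort_values("TFLOPS", ascending=False).head())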

Smaller sets

if df.shape[0] > 0:
    # keep a few configurations; tests 5 and 7 are only scheduled on
    # devices supporting float 8 and are absent from this run
    subset = {1, 3, 4, 5, 7}
    dfis = dfi[dfi.test.isin(subset)]
    print()
    print("t-gemm_sync")
    pivi = dfis.pivot_table(index=["~dim", "mnk"], columns="name", values="t-gemm_sync")
    print(pivi)
    print()
    print("t-total")
    pivi = dfis.pivot_table(index=["~dim", "mnk"], columns="name", values="t-total")
    print(pivi)
t-gemm_sync
name                 BF16xBF16->BF16C68  F16xF16->F16C64  F32xF32->F32C77
~dim      mnk
16.00000  16x16x16             0.000106         0.001020         0.000084
32.00000  32x32x32             0.000228         0.003417         0.000139
64.00000  64x64x64             0.000286         0.006414         0.000149
108.51728 64x128x92            0.000385         0.009356         0.000183

t-total
name                 BF16xBF16->BF16C68  F16xF16->F16C64  F32xF32->F32C77
~dim      mnk
16.00000  16x16x16             0.000115         0.001053         0.000093
32.00000  32x32x32             0.000237         0.003493         0.000154
64.00000  64x64x64             0.000296         0.006469         0.000159
108.51728 64x128x92            0.000401         0.009382         0.000194

Plots

if df.shape[0] > 0:
    piv = df.pivot_table(index=["~dim", "mnk"], columns="name", values="t-gemm_sync")
    piv.plot(title="MatMul performances")

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    piv.plot(ax=ax[0], title="Gemm performance\nlower is better", logx=True, logy=True)

    piv = df[df.test.isin(subset)].pivot_table(
        index=["~dim", "mnk"], columns="name", values="t-gemm_sync"
    )
    if piv.shape[0] > 0:
        piv.plot(
            ax=ax[1], title="Gemm performance\nlower is better", logx=True, logy=True
        )

    fig.tight_layout()
    fig.savefig("plot_bench_gemm_f8.png")
(Figures: "MatMul performances"; a two-panel "Gemm performance, lower is better" on log-log axes, the right panel restricted to the smaller subset.)

Total running time of the script: (0 minutes 6.638 seconds)
