Measuring Gemm performance with different input and output types#

This benchmark looks into various combinations allowed by the function cublasLtMatmul. The tested configurations are defined in cuda_gemm.cu.
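As a reminder of what each configuration times, cublasLtMatmul computes D = alpha * A @ B + beta * C. Below is a minimal NumPy reference for the (m, n, k) convention used in the rest of this page, assuming the usual alpha=1, beta=0 case:

import numpy as np

m, n, k = 64, 128, 92  # one of the rectangular problems benchmarked below
A = np.random.rand(m, k).astype(np.float32)  # left operand, shape (m, k)
B = np.random.rand(k, n).astype(np.float32)  # right operand, shape (k, n)
D = A @ B  # what the GPU kernel produces, shape (m, n)
assert D.shape == (m, n)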

import pprint
import warnings
from itertools import product
from tqdm import tqdm
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnx_extended.ext_test_case import unit_test_going, get_parsed_args

try:
    from onnx_extended.validation.cuda.cuda_example_py import (
        gemm_benchmark_test,
        get_device_prop,
    )

    has_cuda = True
except ImportError:
    # CUDA not available.
    has_cuda = False
    gemm_benchmark_test = None

if has_cuda:
    prop = get_device_prop()
    if prop["major"] <= 0:
        # No CUDA.
        dtests, ddims = "", ""
    elif prop["major"] < 7:
        # No float 8.
        dtests, ddims = "0,1,2,3,4,15", "16,32,64,64x128x92"
    elif prop["major"] < 9:  # T100, A100
        # No float 8.
        dtests, ddims = (
            "0,1,2,3,4,15",
            "16,32,64,128,256,512,1024,2048,4096,8192,"
            "128x768x768,128x3072x768,128x768x3072",
        )
    else:
        dtests, ddims = (
            "0,1,2,3,4,5,6,7,11,14,15",
            "16,32,64,128,256,512,1024,2048,4096,8192,16384,"
            "128x768x768,128x3072x768,128x768x3072",
        )
else:
    dtests, ddims = "", ""


script_args = get_parsed_args(
    "plot_bench_gemm_f8",
    description=__doc__,
    dims=(
        "16,32" if unit_test_going() else ddims,
        "matrix dimensions to try, comma separated values; "
        "a plain integer means a square problem, 'mxnxk' a rectangular one",
    ),
    tests=(
        "0,1,2" if unit_test_going() else dtests,
        "configuration to check, see cuda_gemm.cu",
    ),
    warmup=2 if unit_test_going() else 5,
    repeat=2 if unit_test_going() else 10,
    expose="repeat,warmup",
)
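All four arguments can be overridden on the command line. A hypothetical invocation, with argument names taken from the declarations above:

# python plot_bench_gemm_f8.py --dims=16,32,64 --tests=0,1,2 --warmup=5 --repeat=10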

Device#

if has_cuda:
    prop = get_device_prop()
    pprint.pprint(prop)
else:
    print("CUDA is not available")
    prop = dict(major=0)
{'clockRate': 1569000,
 'computeMode': 0,
 'concurrentKernels': 1,
 'isMultiGpuBoard': 0,
 'major': 6,
 'maxThreadsPerBlock': 1024,
 'minor': 1,
 'multiProcessorCount': 10,
 'name': 'NVIDIA GeForce GTX 1060',
 'sharedMemPerBlock': 49152,
 'totalConstMem': 65536,
 'totalGlobalMem': 6442319872}
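The pair major.minor above is the compute capability: the GTX 1060 reports 6.1, which is why the defaults computed earlier select only tests 0,1,2,3,4,15 and the small dimension set.

print(f"compute capability: {prop['major']}.{prop['minor']}")  # 6.1 here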

Benchmark#

def type2string(dt):
    dtypes = {
        0: "F32",
        2: "F16",
        14: "BF16",
        28: "E4M3",
        29: "E5M2",
        3: "I8",
        10: "I32",
    }
    return dtypes[int(dt)]
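# The integer codes are cudaDataType_t values from the CUDA headers
# (quoted from memory, worth verifying against library_types.h):
assert type2string(0) == "F32"    # CUDA_R_32F
assert type2string(2) == "F16"    # CUDA_R_16F
assert type2string(14) == "BF16"  # CUDA_R_16BF
assert type2string(28) == "E4M3"  # CUDA_R_8F_E4M3, one of the two float 8 formats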


dims = []
tests = []
if gemm_benchmark_test is not None:
    for d in script_args.dims.split(","):
        if "x" in d:
            spl = d.split("x")
            m, n, k = tuple(int(i) for i in spl)
            dims.append((m, n, k))
        else:
            dims.append(int(d))
    tests = [int(i) for i in script_args.tests.split(",")]
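    # For example, "128x768x768" parses to the (m, n, k) triple (128, 768, 768),
    # while "256" stays a single int and denotes a square 256x256x256 problem.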

pbar = tqdm(list(product(tests, dims)))
obs = []
for test, dim in pbar:
    pbar.set_description(f"type={test} dim={dim}")
    if test in {8, 9, 10, 12, 13}:
        warnings.warn(f"unsupported configuration {test}.")
        continue
    mdim = dim if isinstance(dim, int) else max(dim)
    if mdim < 128:
        n_warmup, n_repeat = script_args.warmup * 8, script_args.repeat * 8
    elif mdim < 512:
        n_warmup, n_repeat = script_args.warmup * 4, script_args.repeat * 4
    elif mdim < 8192:
        n_warmup, n_repeat = script_args.warmup * 2, script_args.repeat * 2
    else:
        n_warmup, n_repeat = script_args.warmup, script_args.repeat

    if isinstance(dim, int):
        gemm_args = [dim] * 6
    else:
        m, n, k = dim
        lda, ldb, ldd = k, k, k
        gemm_args = [m, n, k, lda, ldb, ldd]

    # warmup
    gemm_benchmark_test(test, n_warmup, *gemm_args)

    # benchmark
    res = gemm_benchmark_test(test, n_repeat, *gemm_args)

    # better rendering
    res["test"] = test
    update = {}
    for k, v in res.items():
        if "type_" in k:
            update[k] = type2string(v)
        if k.startswith("t-"):
            update[k] = res[k] / res["N"]
    update["compute_type"] = f"C{int(res['compute_type'])}"
    for c in ["N", "m", "n", "k", "lda", "ldb", "ldd"]:
        update[c] = int(res[c])
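    # "~dim" reduces (m, n, k) to one scalar, sqrt(k * max(m, n)), so rectangular
    # problems can be sorted on the same axis as square ones:
    # 64x128x92 gives sqrt(92 * 128) ≈ 108.52, as in the tables below.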
    update["~dim"] = (update["k"] * max(update["m"], update["n"])) ** 0.5
    update["mnk"] = f"{update['m']}x{update['n']}x{update['k']}"
    update["name"] = (
        f"{update['type_a']}x{update['type_b']}->"
        f"{update['type_d']}{update['compute_type']}"
    )
    res.update(update)
    obs.append(res)
    if unit_test_going() and len(obs) > 2:
        break

df = DataFrame(obs)
df.to_csv("plot_bench_gemm_f8.csv", index=False)
df.to_excel("plot_bench_gemm_f8.xlsx", index=False)
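# to_excel needs an engine such as openpyxl installed; drop this line if unavailable.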
print(df.head().T)

df.head().T
type=15 dim=(64, 128, 92): 100%|██████████| 24/24 [00:10<00:00,  2.39it/s]
                                0  ...                4
t-total                  0.000257  ...          0.00024
t-clean                  0.000002  ...         0.000005
t-gemm_in                0.000041  ...         0.000048
t-setup                  0.000025  ...          0.00003
t-stream_create               0.0  ...              0.0
N                              80  ...               80
epiloque                      1.0  ...              1.0
ldd                            16  ...               16
t-workspace_free         0.000017  ...         0.000016
algo                         11.0  ...             11.0
t-gemm_sync              0.000218  ...         0.000198
t-stream_destroy         0.000005  ...         0.000007
workspace_size          1048576.0  ...        1048576.0
m                              16  ...               16
k                              16  ...               16
n                              16  ...               16
compute_type                  C68  ...              C77
lda                            16  ...               16
type_a                        F32  ...              F32
ldb                            16  ...               16
t-gemm                   0.000071  ...         0.000085
type_b                        F32  ...              F32
t-workspace_new          0.000009  ...         0.000011
type_d                        F32  ...              F32
test                            0  ...                1
~dim                         16.0  ...             16.0
mnk                      16x16x16  ...         16x16x16
name              F32xF32->F32C68  ...  F32xF32->F32C77

[28 rows x 5 columns]
                                0                1                2                3                4
t-total                  0.000257          0.00042         0.000469           0.0004          0.00024
t-clean                  0.000002         0.000004         0.000003         0.000005         0.000005
t-gemm_in                0.000041         0.000076         0.000114          0.00007         0.000048
t-setup                  0.000025         0.000046          0.00005          0.00004          0.00003
t-stream_create               0.0              0.0              0.0              0.0              0.0
N                              80               80               80               40               80
epiloque                      1.0              1.0              1.0              1.0              1.0
ldd                            16               32               64               92               16
t-workspace_free         0.000017          0.00003         0.000014         0.000024         0.000016
algo                         11.0              0.0              0.0              0.0             11.0
t-gemm_sync              0.000218          0.00035         0.000397         0.000342         0.000198
t-stream_destroy         0.000005         0.000018         0.000038         0.000017         0.000007
workspace_size          1048576.0        1048576.0        1048576.0        1048576.0        1048576.0
m                              16               32               64               64               16
k                              16               32               64               92               16
n                              16               32               64              128               16
compute_type                  C68              C68              C68              C68              C77
lda                            16               32               64               92               16
type_a                        F32              F32              F32              F32              F32
ldb                            16               32               64               92               16
t-gemm                   0.000071         0.000131         0.000171         0.000118         0.000085
type_b                        F32              F32              F32              F32              F32
t-workspace_new          0.000009         0.000012         0.000012         0.000011         0.000011
type_d                        F32              F32              F32              F32              F32
test                            0                0                0                0                1
~dim                         16.0             32.0             64.0        108.51728             16.0
mnk                      16x16x16         32x32x32         64x64x64        64x128x92         16x16x16
name              F32xF32->F32C68  F32xF32->F32C68  F32xF32->F32C68  F32xF32->F32C68  F32xF32->F32C77


Test definition#

col_def = ["name", "test", "type_a", "type_b", "type_d", "compute_type"]
if df.shape[0] > 0:
    deft = df.copy()
    gr = deft[col_def].groupby(col_def, as_index=False).count()
    print(gr)
                 name  test type_a type_b type_d compute_type
0  BF16xBF16->BF16C68     4   BF16   BF16   BF16          C68
1     F16xF16->F16C64     3    F16    F16    F16          C64
2     F32xF32->F32C68     0    F32    F32    F32          C68
3     F32xF32->F32C75     2    F32    F32    F32          C75
4     F32xF32->F32C77     1    F32    F32    F32          C77
5       I8xI8->I32C72    15     I8     I8    I32          C72

Total time and only gemm#

if df.shape[0] > 0:
    dfi = df[col_def + ["~dim", "mnk", "t-total", "t-gemm_sync"]]
    print(dfi)
                  name  test type_a  ...        mnk   t-total t-gemm_sync
0      F32xF32->F32C68     0    F32  ...   16x16x16  0.000257    0.000218
1      F32xF32->F32C68     0    F32  ...   32x32x32  0.000420    0.000350
2      F32xF32->F32C68     0    F32  ...   64x64x64  0.000469    0.000397
3      F32xF32->F32C68     0    F32  ...  64x128x92  0.000400    0.000342
4      F32xF32->F32C77     1    F32  ...   16x16x16  0.000240    0.000198
5      F32xF32->F32C77     1    F32  ...   32x32x32  0.000359    0.000307
6      F32xF32->F32C77     1    F32  ...   64x64x64  0.000216    0.000193
7      F32xF32->F32C77     1    F32  ...  64x128x92  0.000417    0.000378
8      F32xF32->F32C75     2    F32  ...   16x16x16  0.000339    0.000303
9      F32xF32->F32C75     2    F32  ...   32x32x32  0.000451    0.000407
10     F32xF32->F32C75     2    F32  ...   64x64x64  0.000385    0.000341
11     F32xF32->F32C75     2    F32  ...  64x128x92  0.000486    0.000430
12     F16xF16->F16C64     3    F16  ...   16x16x16  0.001203    0.001172
13     F16xF16->F16C64     3    F16  ...   32x32x32  0.003306    0.003290
14     F16xF16->F16C64     3    F16  ...   64x64x64  0.006266    0.006250
15     F16xF16->F16C64     3    F16  ...  64x128x92  0.009315    0.009286
16  BF16xBF16->BF16C68     4   BF16  ...   16x16x16  0.000211    0.000149
17  BF16xBF16->BF16C68     4   BF16  ...   32x32x32  0.000261    0.000245
18  BF16xBF16->BF16C68     4   BF16  ...   64x64x64  0.000374    0.000356
19  BF16xBF16->BF16C68     4   BF16  ...  64x128x92  0.000379    0.000363
20       I8xI8->I32C72    15     I8  ...   16x16x16  0.000147    0.000131
21       I8xI8->I32C72    15     I8  ...   32x32x32  0.000210    0.000192
22       I8xI8->I32C72    15     I8  ...   64x64x64  0.000305    0.000253
23       I8xI8->I32C72    15     I8  ...  64x128x92  0.000307    0.000290

[24 rows x 10 columns]
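Judging from the column names, t-total covers the whole call (stream and workspace management, setup, cleanup) while t-gemm_sync is only the kernel launch followed by a synchronization, so their difference estimates the fixed overhead. A small sketch:

if df.shape[0] > 0:
    over = dfi.copy()
    over["overhead"] = over["t-total"] - over["t-gemm_sync"]
    print(over[["name", "mnk", "overhead"]])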

Smaller sets#

if df.shape[0] > 0:
    subset = {1, 3, 4, 5, 7}
    dfis = dfi[dfi.test.isin(subset)]
    print()
    print("t-gemm_sync")
    pivi = dfis.pivot_table(index=["~dim", "mnk"], columns="name", values="t-gemm_sync")
    print(pivi)
    print()
    print("t-total")
    pivi = dfis.pivot_table(index=["~dim", "mnk"], columns="name", values="t-total")
    print(pivi)
t-gemm_sync
name                 BF16xBF16->BF16C68  F16xF16->F16C64  F32xF32->F32C77
~dim      mnk
16.00000  16x16x16             0.000149         0.001172         0.000198
32.00000  32x32x32             0.000245         0.003290         0.000307
64.00000  64x64x64             0.000356         0.006250         0.000193
108.51728 64x128x92            0.000363         0.009286         0.000378

t-total
name                 BF16xBF16->BF16C68  F16xF16->F16C64  F32xF32->F32C77
~dim      mnk
16.00000  16x16x16             0.000211         0.001203         0.000240
32.00000  32x32x32             0.000261         0.003306         0.000359
64.00000  64x64x64             0.000374         0.006266         0.000216
108.51728 64x128x92            0.000379         0.009315         0.000417
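With the pivot in place, comparing two configurations is a single column division; for instance, using the column names printed above:

ratio = pivi["F16xF16->F16C64"] / pivi["BF16xBF16->BF16C68"]
print(ratio)  # how much slower F16 runs than BF16 on this device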

Plots#

if df.shape[0] > 0:
    piv = df.pivot_table(index=["~dim", "mnk"], columns="name", values="t-gemm_sync")
    piv.plot(title="MatMul performances")

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    piv.plot(ax=ax[0], title="Gemm performance\nlower is better", logx=True, logy=True)

    piv = df[df.test.isin(subset)].pivot_table(
        index=["~dim", "mnk"], columns="name", values="t-gemm_sync"
    )
    if piv.shape[0] > 0:
        piv.plot(
            ax=ax[1], title="Gemm performance\nlower is better", logx=True, logy=True
        )

    fig.tight_layout()
    fig.savefig("plot_bench_gemm_f8.png")
[Figures: "MatMul performances" and the two-panel "Gemm performance, lower is better" log-log comparison, saved as plot_bench_gemm_f8.png]

Total running time of the script: (0 minutes 16.708 seconds)
