Note
Go to the end to download the full example code.
Measuring Gemm performance with different input and output tests¶
This benchmark looks into the various combinations allowed by the function cublasLtMatmul. The tested configurations are listed in cuda_gemm.cu.
import pprint
import warnings
from itertools import product
from tqdm import tqdm
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnx_extended.args import get_parsed_args
from onnx_extended.ext_test_case import unit_test_going
# The CUDA helpers come from an optional module: when it cannot be imported,
# the script still runs but has nothing to benchmark.
try:
    from onnx_extended.validation.cuda.cuda_example_py import (
        gemm_benchmark_test,
        get_device_prop,
    )
except ImportError:
    # CUDA not available.
    gemm_benchmark_test = None
    has_cuda = False
else:
    has_cuda = True
# Default test ids (``dtests``) and problem sizes (``ddims``), both stored as
# comma separated strings.  They stay empty unless a device reporting a
# positive major version is found.
dtests, ddims = "", ""
if has_cuda:
    prop = get_device_prop()
    if 0 < prop["major"] < 7:
        # No float 8.
        dtests = "0,1,2,3,4,15"
        ddims = "16,32,64,64x128x92"
    elif 7 <= prop["major"] < 9:
        # No float 8.
        dtests = "0,1,2,3,4,15"
        ddims = "16,32,64,128,128x128x128,128x512x128,128x512x512"
    elif prop["major"] >= 9:
        # Largest set of configurations and sizes.
        dtests = "0,1,2,3,4,5,6,7,11,14,15"
        ddims = (
            "16,32,64,128,256,512,1024,2048,4096,8192,16384,"
            "128x768x768,128x3072x768,128x768x3072"
        )
# Command line options.  A unit-test run uses a much smaller workload than a
# full benchmark run, so the defaults are picked before building the parser.
if unit_test_going():
    default_dims, default_tests = "16,32", "0,1,2"
    default_warmup, default_repeat = 2, 2
else:
    default_dims, default_tests = ddims, dtests
    default_warmup, default_repeat = 5, 10

script_args = get_parsed_args(
    "plot_bench_gemm_f8",
    description=__doc__,
    dims=(
        default_dims,
        "square matrix dimensions to try, comma separated values",
    ),
    tests=(
        default_tests,
        "configuration to check, see cuda_gemm.cu",
    ),
    warmup=default_warmup,
    repeat=default_repeat,
    expose="repeat,warmup",
)
Device¶
# Show the device properties, or a placeholder when CUDA is missing.
if not has_cuda:
    print("CUDA is not available")
    prop = {"major": 0}
else:
    prop = get_device_prop()
    pprint.pprint(prop)
{'clockRate': 2010000,
'computeMode': 0,
'concurrentKernels': 1,
'isMultiGpuBoard': 0,
'major': 8,
'maxThreadsPerBlock': 1024,
'minor': 9,
'multiProcessorCount': 24,
'name': 'NVIDIA GeForce RTX 4060 Laptop GPU',
'sharedMemPerBlock': 49152,
'totalConstMem': 65536,
'totalGlobalMem': 8585281536}
Benchmark¶
# Short labels for the numeric type codes found in the benchmark results.
# NOTE(review): the codes look like CUDA ``cudaDataType`` values -- confirm
# against cuda_gemm.cu.
_TYPE_NAMES = {
    0: "F32",
    2: "F16",
    14: "BF16",
    28: "E4M3",
    29: "E5M2",
    3: "I8",
    10: "I32",
}


def type2string(dt):
    """Return the short label associated with a numeric type code.

    :param dt: type code, anything accepted by ``int`` (the benchmark
        results store these codes as floats)
    :return: one of ``F32``, ``F16``, ``BF16``, ``E4M3``, ``E5M2``,
        ``I8``, ``I32``
    :raises KeyError: if the code is not in the table
    """
    code = int(dt)
    try:
        return _TYPE_NAMES[code]
    except KeyError:
        # Still a KeyError, but it now says what was expected.
        raise KeyError(
            f"unknown data type code {code}, expected one of {sorted(_TYPE_NAMES)}"
        ) from None
# Parse the command line values into the lists the benchmark iterates over.
# ``dims`` holds either an int (square problem) or a tuple ``(m, n, k)``
# written ``MxNxK``; ``tests`` holds the configuration ids.
dims = []
tests = []
if gemm_benchmark_test is not None:
    for token in script_args.dims.split(","):
        token = token.strip()
        if not token:
            # The default is an empty string when no usable device was found:
            # ``"".split(",")`` yields ``[""]`` and ``int("")`` would raise.
            continue
        if "x" in token:
            # Unpacking into three names rejects anything but M, N and K.
            m, n, k = (int(i) for i in token.split("x"))
            dims.append((m, n, k))
        else:
            dims.append(int(token))
    tests = [int(t) for t in script_args.tests.split(",") if t.strip()]
# Run every (configuration, dimension) pair and collect one row of
# measurements per successful run in ``obs``.
pbar = tqdm(list(product(tests, dims)))
obs = []
for test, dim in pbar:
    pbar.set_description(f"type={test} dim={dim}")
    if test in {8, 9, 10, 12, 13}:
        warnings.warn(f"unsupported configuration {test}.", stacklevel=0)
        continue

    # Small problems run more iterations than big ones.
    mdim = dim if isinstance(dim, int) else max(dim)
    if mdim < 128:
        factor = 8
    elif mdim < 512:
        factor = 4
    elif mdim < 8192:
        factor = 2
    else:
        factor = 1
    # The warmup count used to be stored in ``n``, which was then overwritten
    # by the unpacking of ``dim`` below and never passed to the warmup call.
    n_warmup = script_args.warmup * factor
    n_repeat = script_args.repeat * factor

    if isinstance(dim, int):
        # Square problem: m, n, k and the three leading dimensions are equal.
        gemm_args = [dim] * 6
    else:
        m, n, k = dim
        # NOTE(review): every leading dimension is set to k -- confirm this is
        # what cuda_gemm.cu expects for non square problems.
        lda, ldb, ldd = k, k, k
        gemm_args = [m, n, k, lda, ldb, ldd]

    # warmup
    try:
        gemm_benchmark_test(test, n_warmup, *gemm_args)
    except RuntimeError:
        # Not working.
        continue

    # benchmark
    res = gemm_benchmark_test(test, n_repeat, *gemm_args)

    # better rendering
    res["test"] = test
    update = {}
    for key, value in res.items():
        if "type_" in key:
            update[key] = type2string(value)
        if key.startswith("t-"):
            # Timings are totals over ``res["N"]`` runs: store the average.
            update[key] = value / res["N"]
    update["compute_type"] = f"C{int(res['compute_type'])}"
    for c in ["N", "m", "n", "k", "lda", "ldb", "ldd"]:
        update[c] = int(res[c])
    # Single number summarising the problem size, used as the plot index.
    update["~dim"] = (update["k"] * max(update["m"], update["n"])) ** 0.5
    update["mnk"] = f"{update['m']}x{update['n']}x{update['k']}"
    update["name"] = (
        f"{update['type_a']}x{update['type_b']}->"
        f"{update['type_d']}{update['compute_type']}"
    )
    res.update(update)
    obs.append(res)
    if unit_test_going() and len(obs) > 2:
        # Three rows are enough for a unit test.
        break
# Gather the measurements in a DataFrame, save them to disk and show the
# first rows transposed (one column per run).
df = DataFrame(obs)
df.to_csv("plot_bench_gemm_f8.csv", index=False)
df.to_excel("plot_bench_gemm_f8.xlsx", index=False)

preview = df.head().T
print(preview)
preview
0%| | 0/42 [00:00<?, ?it/s]
type=0 dim=16: 0%| | 0/42 [00:00<?, ?it/s]
type=0 dim=16: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=0 dim=32: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=0 dim=64: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=0 dim=128: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=0 dim=(128, 128, 128): 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=0 dim=(128, 512, 128): 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=0 dim=(128, 512, 512): 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=1 dim=16: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=1 dim=32: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=1 dim=64: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=1 dim=128: 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=1 dim=(128, 128, 128): 2%|▏ | 1/42 [00:01<01:17, 1.88s/it]
type=1 dim=(128, 128, 128): 29%|██▊ | 12/42 [00:01<00:03, 8.23it/s]
type=1 dim=(128, 512, 128): 29%|██▊ | 12/42 [00:01<00:03, 8.23it/s]
type=1 dim=(128, 512, 512): 29%|██▊ | 12/42 [00:01<00:03, 8.23it/s]
type=2 dim=16: 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=32: 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=64: 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=128: 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=(128, 128, 128): 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=(128, 512, 128): 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=(128, 512, 512): 29%|██▊ | 12/42 [00:02<00:03, 8.23it/s]
type=2 dim=(128, 512, 512): 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=16: 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=32: 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=64: 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=128: 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=(128, 128, 128): 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=(128, 512, 128): 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=3 dim=(128, 512, 512): 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=4 dim=16: 50%|█████ | 21/42 [00:02<00:01, 15.24it/s]
type=4 dim=16: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=4 dim=32: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=4 dim=64: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=4 dim=128: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=4 dim=(128, 128, 128): 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=4 dim=(128, 512, 128): 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=4 dim=(128, 512, 512): 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=16: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=32: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=64: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=128: 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=(128, 128, 128): 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=(128, 512, 128): 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=(128, 512, 512): 69%|██████▉ | 29/42 [00:02<00:00, 19.87it/s]
type=15 dim=(128, 512, 512): 100%|██████████| 42/42 [00:02<00:00, 17.44it/s]
0 1 2 3 4
t-total 0.000052 0.000047 0.000044 0.000042 0.000043
t-clean 0.0 0.0 0.0 0.0 0.0
t-gemm_in 0.000009 0.00001 0.00001 0.000007 0.000008
t-setup 0.000002 0.000001 0.000001 0.000002 0.000002
t-stream_create 0.0 0.0 0.0 0.0 0.0
N 80 80 80 40 40
epiloque 1.0 1.0 1.0 1.0 1.0
ldd 16 32 64 128 128
t-workspace_free 0.000002 0.000002 0.000002 0.000002 0.000002
algo 11.0 1.0 1.0 1.0 1.0
t-gemm_sync 0.000045 0.000042 0.000039 0.000037 0.000036
t-stream_destroy 0.000001 0.000001 0.000001 0.000001 0.000001
workspace_size 1048576.0 1048576.0 1048576.0 1048576.0 1048576.0
m 16 32 64 128 128
k 16 32 64 128 128
n 16 32 64 128 128
compute_type C68 C68 C68 C68 C68
lda 16 32 64 128 128
type_a F32 F32 F32 F32 F32
ldb 16 32 64 128 128
t-gemm 0.000011 0.000012 0.000012 0.000009 0.00001
type_b F32 F32 F32 F32 F32
t-workspace_new 0.000002 0.000002 0.000002 0.000002 0.000002
type_d F32 F32 F32 F32 F32
test 0 0 0 0 0
~dim 16.0 32.0 64.0 128.0 128.0
mnk 16x16x16 32x32x32 64x64x64 128x128x128 128x128x128
name F32xF32->F32C68 F32xF32->F32C68 F32xF32->F32C68 F32xF32->F32C68 F32xF32->F32C68
Test definition¶
name test type_a type_b type_d compute_type
0 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68
1 F32xF32->F32C68 0 F32 F32 F32 C68
2 F32xF32->F32C75 2 F32 F32 F32 C75
3 F32xF32->F32C77 1 F32 F32 F32 C77
Total time and only gemm¶
name test type_a type_b type_d compute_type ~dim mnk t-total t-gemm_sync
0 F32xF32->F32C68 0 F32 F32 F32 C68 16.0 16x16x16 0.000052 0.000045
1 F32xF32->F32C68 0 F32 F32 F32 C68 32.0 32x32x32 0.000047 0.000042
2 F32xF32->F32C68 0 F32 F32 F32 C68 64.0 64x64x64 0.000044 0.000039
3 F32xF32->F32C68 0 F32 F32 F32 C68 128.0 128x128x128 0.000042 0.000037
4 F32xF32->F32C68 0 F32 F32 F32 C68 128.0 128x128x128 0.000043 0.000036
5 F32xF32->F32C68 0 F32 F32 F32 C68 256.0 128x512x128 0.000044 0.000038
6 F32xF32->F32C68 0 F32 F32 F32 C68 512.0 128x512x512 0.000577 0.000080
7 F32xF32->F32C77 1 F32 F32 F32 C77 16.0 16x16x16 0.000044 0.000038
8 F32xF32->F32C77 1 F32 F32 F32 C77 32.0 32x32x32 0.000045 0.000039
9 F32xF32->F32C77 1 F32 F32 F32 C77 64.0 64x64x64 0.000044 0.000038
10 F32xF32->F32C77 1 F32 F32 F32 C77 128.0 128x128x128 0.000042 0.000036
11 F32xF32->F32C77 1 F32 F32 F32 C77 128.0 128x128x128 0.000050 0.000045
12 F32xF32->F32C77 1 F32 F32 F32 C77 256.0 128x512x128 0.000068 0.000057
13 F32xF32->F32C77 1 F32 F32 F32 C77 512.0 128x512x512 0.000691 0.000101
14 F32xF32->F32C75 2 F32 F32 F32 C75 16.0 16x16x16 0.000046 0.000039
15 F32xF32->F32C75 2 F32 F32 F32 C75 32.0 32x32x32 0.000110 0.000095
16 F32xF32->F32C75 2 F32 F32 F32 C75 64.0 64x64x64 0.000055 0.000047
17 F32xF32->F32C75 2 F32 F32 F32 C75 128.0 128x128x128 0.000049 0.000042
18 F32xF32->F32C75 2 F32 F32 F32 C75 128.0 128x128x128 0.000045 0.000039
19 F32xF32->F32C75 2 F32 F32 F32 C75 256.0 128x512x128 0.000060 0.000053
20 F32xF32->F32C75 2 F32 F32 F32 C75 512.0 128x512x512 0.000631 0.000111
21 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 16.0 16x16x16 0.000038 0.000033
22 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 32.0 32x32x32 0.000041 0.000035
23 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 64.0 64x64x64 0.000038 0.000033
24 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 128.0 128x128x128 0.000053 0.000044
25 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 128.0 128x128x128 0.000042 0.000036
26 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 256.0 128x512x128 0.000046 0.000040
27 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 512.0 128x512x512 0.000054 0.000045
Smaller sets¶
# Restrict the results to a few configurations and print one pivot table
# per timing column (rows: problem size, columns: configuration name).
if df.shape[0] > 0:
    subset = {1, 3, 4, 5, 7}
    dfis = dfi[dfi.test.isin(subset)]
    for metric in ("t-gemm_sync", "t-total"):
        print()
        print(metric)
        pivi = dfis.pivot_table(index=["~dim", "mnk"], columns="name", values=metric)
        print(pivi)
t-gemm_sync
name BF16xBF16->BF16C68 F32xF32->F32C77
~dim mnk
16.0 16x16x16 0.000033 0.000038
32.0 32x32x32 0.000035 0.000039
64.0 64x64x64 0.000033 0.000038
128.0 128x128x128 0.000040 0.000041
256.0 128x512x128 0.000040 0.000057
512.0 128x512x512 0.000045 0.000101
t-total
name BF16xBF16->BF16C68 F32xF32->F32C77
~dim mnk
16.0 16x16x16 0.000038 0.000044
32.0 32x32x32 0.000041 0.000045
64.0 64x64x64 0.000038 0.000044
128.0 128x128x128 0.000048 0.000046
256.0 128x512x128 0.000046 0.000068
512.0 128x512x512 0.000054 0.000691
Plots¶
# Plot ``t-gemm_sync`` against the problem size: every configuration on the
# left axis, the configurations listed in ``subset`` on the right axis.
if df.shape[0] > 0:
    pivot_spec = dict(index=["~dim", "mnk"], columns="name", values="t-gemm_sync")
    log_plot = dict(title="Gemm performance\nlower is better", logx=True, logy=True)

    piv = df.pivot_table(**pivot_spec)
    piv.plot(title="MatMul performances")

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    piv.plot(ax=ax[0], **log_plot)

    piv = df[df.test.isin(subset)].pivot_table(**pivot_spec)
    if piv.shape[0] > 0:
        # The right axis stays empty when no run belongs to the subset.
        piv.plot(ax=ax[1], **log_plot)

    fig.tight_layout()
    fig.savefig("plot_bench_gemm_f8.png")
Total running time of the script: (0 minutes 5.987 seconds)