Note
Go to the end to download the full example code
Measuring Gemm performance with different input and output tests¶
This benchmark looks into the various combinations allowed by the function cublasLtMatmul. The tested configurations are listed in cuda_gemm.cu.
import pprint
import warnings
from itertools import product
from tqdm import tqdm
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnx_extended.args import get_parsed_args
from onnx_extended.ext_test_case import unit_test_going
# The CUDA extension is optional: the script still runs without it and simply
# produces no measurements.
try:
    from onnx_extended.validation.cuda.cuda_example_py import (
        gemm_benchmark_test,
        get_device_prop,
    )
except ImportError:
    # The extension could not be imported: flag CUDA as unavailable and leave
    # a placeholder so later sections can test `gemm_benchmark_test is None`.
    has_cuda = False
    gemm_benchmark_test = None
else:
    has_cuda = True
# Default list of configurations (dtests) and matrix sizes (ddims), both
# stored as comma separated strings. They stay empty unless a usable device
# is found; otherwise they grow with the `major` value reported by the device.
dtests, ddims = "", ""
if has_cuda:
    prop = get_device_prop()
    if prop["major"] >= 9:
        # Largest set of configurations and dimensions
        # (the extra ones are presumably the float 8 tests -- see cuda_gemm.cu).
        dtests = "0,1,2,3,4,5,6,7,11,14,15"
        ddims = (
            "16,32,64,128,256,512,1024,2048,4096,8192,16384,"
            "128x768x768,128x3072x768,128x768x3072"
        )
    elif prop["major"] >= 7:
        # No float 8.
        dtests = "0,1,2,3,4,15"
        ddims = (
            "16,32,64,128,256,512,1024,2048,4096,8192,"
            "128x768x768,128x3072x768,128x768x3072"
        )
    elif prop["major"] > 0:
        # No float 8, and only small dimensions.
        dtests = "0,1,2,3,4,15"
        ddims = "16,32,64,64x128x92"
    # prop["major"] <= 0: no CUDA, both strings stay empty.
# Command line options. Each (default, help) pair declares one option.
# When unit tests are running, much smaller defaults keep the example fast.
script_args = get_parsed_args(
    "plot_bench_gemm_f8",
    description=__doc__,
    dims=(
        ddims if not unit_test_going() else "16,32",
        "square matrix dimensions to try, comma separated values",
    ),
    tests=(
        dtests if not unit_test_going() else "0,1,2",
        "configuration to check, see cuda_gemm.cu",
    ),
    warmup=5 if not unit_test_going() else 2,
    repeat=10 if not unit_test_going() else 2,
    expose="repeat,warmup",
)
Device¶
# Show the device properties, or fall back to a stub with major=0 when the
# CUDA extension is missing.
if not has_cuda:
    print("CUDA is not available")
    prop = {"major": 0}
else:
    prop = get_device_prop()
    pprint.pprint(prop)
{'clockRate': 1569000,
'computeMode': 0,
'concurrentKernels': 1,
'isMultiGpuBoard': 0,
'major': 6,
'maxThreadsPerBlock': 1024,
'minor': 1,
'multiProcessorCount': 10,
'name': 'NVIDIA GeForce GTX 1060',
'sharedMemPerBlock': 49152,
'totalConstMem': 65536,
'totalGlobalMem': 6442319872}
Benchmark¶
def type2string(dt):
    """Return the short label associated with a numeric data type code.

    :param dt: type code, anything accepted by ``int()``
    :return: one of ``F32``, ``F16``, ``BF16``, ``E4M3``, ``E5M2``,
        ``I8``, ``I32``
    :raises KeyError: if the code is not one of the known values
    """
    code = int(dt)
    labels = dict(
        [
            (0, "F32"),
            (2, "F16"),
            (14, "BF16"),
            (28, "E4M3"),
            (29, "E5M2"),
            (3, "I8"),
            (10, "I32"),
        ]
    )
    return labels[code]
# Parse the command line values into the lists the benchmark iterates over:
# * dims: every entry is either an int (square problem) or a (m, n, k) tuple
#   written as "MxNxK",
# * tests: configuration numbers (see cuda_gemm.cu).
dims = []
tests = []
if gemm_benchmark_test is not None:
    for d in script_args.dims.split(","):
        d = d.strip()
        if not d:
            # An empty option (the default when no usable device is found)
            # splits into [""]; int("") would raise ValueError.
            continue
        if "x" in d:
            # Unpacking into three names rejects anything other than MxNxK.
            m, n, k = (int(i) for i in d.split("x"))
            dims.append((m, n, k))
        else:
            dims.append(int(d))
    # Same protection against empty items for the list of configurations.
    tests = [int(i) for i in script_args.tests.split(",") if i.strip()]
# Run every (configuration, dimension) pair and collect one row of
# measurements per run in `obs`.
pbar = tqdm(list(product(tests, dims)))
obs = []
for test, dim in pbar:
    pbar.set_description(f"type={test} dim={dim}")
    if test in {8, 9, 10, 12, 13}:
        warnings.warn(f"unsupported configuration {test}.")
        continue

    # Small problems are repeated more often than big ones.
    mdim = dim if isinstance(dim, int) else max(dim)
    if mdim < 128:
        factor = 8
    elif mdim < 512:
        factor = 4
    elif mdim < 8192:
        factor = 2
    else:
        factor = 1
    # Distinct names: the previous code stored the warmup count in `n`,
    # which was then overwritten by the matrix dimension below and never
    # used, so the warmup ran with the repeat count.
    n_warmup = script_args.warmup * factor
    n_repeat = script_args.repeat * factor

    if isinstance(dim, int):
        # Square problem: m = n = k = lda = ldb = ldd = dim.
        gemm_args = [dim] * 6
    else:
        m, n, k = dim
        lda, ldb, ldd = k, k, k
        gemm_args = [m, n, k, lda, ldb, ldd]

    # warmup (skipped when no warmup iteration was requested)
    if n_warmup > 0:
        gemm_benchmark_test(test, n_warmup, *gemm_args)

    # benchmark
    res = gemm_benchmark_test(test, n_repeat, *gemm_args)

    # better rendering
    res["test"] = test
    update = {}
    for key, value in res.items():
        if "type_" in key:
            # Numeric type code -> readable label.
            update[key] = type2string(value)
        if key.startswith("t-"):
            # Timings are divided by the number of iterations reported in N.
            update[key] = value / res["N"]
    update["compute_type"] = f"C{int(res['compute_type'])}"
    for c in ["N", "m", "n", "k", "lda", "ldb", "ldd"]:
        update[c] = int(res[c])
    # Single number used to order square and non square problems on one axis.
    update["~dim"] = (update["k"] * max(update["m"], update["n"])) ** 0.5
    update["mnk"] = f"{update['m']}x{update['n']}x{update['k']}"
    update["name"] = (
        f"{update['type_a']}x{update['type_b']}->"
        f"{update['type_d']}{update['compute_type']}"
    )
    res.update(update)
    obs.append(res)
    # A few rows are enough when unit tests are running.
    if unit_test_going() and len(obs) > 2:
        break
# Gather all measurements in a dataframe, save it as CSV and Excel,
# then show the first rows (transposed, one column per run).
df = DataFrame(data=obs)
df.to_csv("plot_bench_gemm_f8.csv", index=False)
df.to_excel("plot_bench_gemm_f8.xlsx", index=False)
print(df.head().transpose())
df.head().transpose()
0%| | 0/24 [00:00<?, ?it/s]
type=0 dim=16: 0%| | 0/24 [00:00<?, ?it/s]
type=0 dim=16: 4%|▍ | 1/24 [00:01<00:37, 1.61s/it]
type=0 dim=32: 4%|▍ | 1/24 [00:01<00:37, 1.61s/it]
type=0 dim=64: 4%|▍ | 1/24 [00:01<00:37, 1.61s/it]
type=0 dim=64: 12%|█▎ | 3/24 [00:01<00:09, 2.19it/s]
type=0 dim=(64, 128, 92): 12%|█▎ | 3/24 [00:01<00:09, 2.19it/s]
type=1 dim=16: 12%|█▎ | 3/24 [00:01<00:09, 2.19it/s]
type=1 dim=32: 12%|█▎ | 3/24 [00:01<00:09, 2.19it/s]
type=1 dim=32: 25%|██▌ | 6/24 [00:01<00:03, 4.84it/s]
type=1 dim=64: 25%|██▌ | 6/24 [00:01<00:03, 4.84it/s]
type=1 dim=(64, 128, 92): 25%|██▌ | 6/24 [00:01<00:03, 4.84it/s]
type=1 dim=(64, 128, 92): 33%|███▎ | 8/24 [00:01<00:02, 6.68it/s]
type=2 dim=16: 33%|███▎ | 8/24 [00:01<00:02, 6.68it/s]
type=2 dim=32: 33%|███▎ | 8/24 [00:02<00:02, 6.68it/s]
type=2 dim=32: 42%|████▏ | 10/24 [00:02<00:01, 8.67it/s]
type=2 dim=64: 42%|████▏ | 10/24 [00:02<00:01, 8.67it/s]
type=2 dim=(64, 128, 92): 42%|████▏ | 10/24 [00:02<00:01, 8.67it/s]
type=2 dim=(64, 128, 92): 50%|█████ | 12/24 [00:02<00:01, 10.52it/s]
type=3 dim=16: 50%|█████ | 12/24 [00:02<00:01, 10.52it/s]
type=3 dim=32: 50%|█████ | 12/24 [00:02<00:01, 10.52it/s]
type=3 dim=32: 58%|█████▊ | 14/24 [00:02<00:01, 5.44it/s]
type=3 dim=64: 58%|█████▊ | 14/24 [00:02<00:01, 5.44it/s]
type=3 dim=(64, 128, 92): 58%|█████▊ | 14/24 [00:03<00:01, 5.44it/s]
type=3 dim=(64, 128, 92): 67%|██████▋ | 16/24 [00:04<00:03, 2.40it/s]
type=4 dim=16: 67%|██████▋ | 16/24 [00:04<00:03, 2.40it/s]
type=4 dim=32: 67%|██████▋ | 16/24 [00:04<00:03, 2.40it/s]
type=4 dim=64: 67%|██████▋ | 16/24 [00:04<00:03, 2.40it/s]
type=4 dim=64: 79%|███████▉ | 19/24 [00:04<00:01, 3.70it/s]
type=4 dim=(64, 128, 92): 79%|███████▉ | 19/24 [00:04<00:01, 3.70it/s]
type=15 dim=16: 79%|███████▉ | 19/24 [00:04<00:01, 3.70it/s]
type=15 dim=32: 79%|███████▉ | 19/24 [00:05<00:01, 3.70it/s]
type=15 dim=32: 92%|█████████▏| 22/24 [00:05<00:00, 5.30it/s]
type=15 dim=64: 92%|█████████▏| 22/24 [00:05<00:00, 5.30it/s]
type=15 dim=(64, 128, 92): 92%|█████████▏| 22/24 [00:05<00:00, 5.30it/s]
type=15 dim=(64, 128, 92): 100%|██████████| 24/24 [00:05<00:00, 4.66it/s]
0 1 2 3 4
t-total 0.000143 0.000132 0.000149 0.000183 0.000093
t-clean 0.000001 0.000001 0.000001 0.000001 0.000001
t-gemm_in 0.000027 0.000011 0.000009 0.000009 0.000008
t-setup 0.000008 0.000012 0.000009 0.000009 0.000005
t-stream_create 0.0 0.0 0.0 0.0 0.0
N 80 80 80 40 80
epiloque 1.0 1.0 1.0 1.0 1.0
ldd 16 32 64 92 16
t-workspace_free 0.000005 0.000004 0.000003 0.000003 0.000003
algo 11.0 0.0 0.0 0.0 11.0
t-gemm_sync 0.000095 0.00012 0.00014 0.000174 0.000084
t-stream_destroy 0.000036 0.000002 0.000001 0.000001 0.000001
workspace_size 1048576.0 1048576.0 1048576.0 1048576.0 1048576.0
m 16 32 64 64 16
k 16 32 64 92 16
n 16 32 64 128 16
compute_type C68 C68 C68 C68 C77
lda 16 32 64 92 16
type_a F32 F32 F32 F32 F32
ldb 16 32 64 92 16
t-gemm 0.000037 0.000025 0.00002 0.00002 0.000015
type_b F32 F32 F32 F32 F32
t-workspace_new 0.000003 0.000004 0.000003 0.000002 0.000002
type_d F32 F32 F32 F32 F32
test 0 0 0 0 1
~dim 16.0 32.0 64.0 108.51728 16.0
mnk 16x16x16 32x32x32 64x64x64 64x128x92 16x16x16
name F32xF32->F32C68 F32xF32->F32C68 F32xF32->F32C68 F32xF32->F32C68 F32xF32->F32C77
Test definition¶
name test type_a type_b type_d compute_type
0 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68
1 F16xF16->F16C64 3 F16 F16 F16 C64
2 F32xF32->F32C68 0 F32 F32 F32 C68
3 F32xF32->F32C75 2 F32 F32 F32 C75
4 F32xF32->F32C77 1 F32 F32 F32 C77
5 I8xI8->I32C72 15 I8 I8 I32 C72
Total time and only gemm¶
name test type_a type_b type_d compute_type ~dim mnk t-total t-gemm_sync
0 F32xF32->F32C68 0 F32 F32 F32 C68 16.00000 16x16x16 0.000143 0.000095
1 F32xF32->F32C68 0 F32 F32 F32 C68 32.00000 32x32x32 0.000132 0.000120
2 F32xF32->F32C68 0 F32 F32 F32 C68 64.00000 64x64x64 0.000149 0.000140
3 F32xF32->F32C68 0 F32 F32 F32 C68 108.51728 64x128x92 0.000183 0.000174
4 F32xF32->F32C77 1 F32 F32 F32 C77 16.00000 16x16x16 0.000093 0.000084
5 F32xF32->F32C77 1 F32 F32 F32 C77 32.00000 32x32x32 0.000154 0.000139
6 F32xF32->F32C77 1 F32 F32 F32 C77 64.00000 64x64x64 0.000159 0.000149
7 F32xF32->F32C77 1 F32 F32 F32 C77 108.51728 64x128x92 0.000194 0.000183
8 F32xF32->F32C75 2 F32 F32 F32 C75 16.00000 16x16x16 0.000090 0.000081
9 F32xF32->F32C75 2 F32 F32 F32 C75 32.00000 32x32x32 0.000170 0.000152
10 F32xF32->F32C75 2 F32 F32 F32 C75 64.00000 64x64x64 0.000158 0.000148
11 F32xF32->F32C75 2 F32 F32 F32 C75 108.51728 64x128x92 0.000195 0.000184
12 F16xF16->F16C64 3 F16 F16 F16 C64 16.00000 16x16x16 0.001053 0.001020
13 F16xF16->F16C64 3 F16 F16 F16 C64 32.00000 32x32x32 0.003493 0.003417
14 F16xF16->F16C64 3 F16 F16 F16 C64 64.00000 64x64x64 0.006469 0.006414
15 F16xF16->F16C64 3 F16 F16 F16 C64 108.51728 64x128x92 0.009382 0.009356
16 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 16.00000 16x16x16 0.000115 0.000106
17 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 32.00000 32x32x32 0.000237 0.000228
18 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 64.00000 64x64x64 0.000296 0.000286
19 BF16xBF16->BF16C68 4 BF16 BF16 BF16 C68 108.51728 64x128x92 0.000401 0.000385
20 I8xI8->I32C72 15 I8 I8 I32 C72 16.00000 16x16x16 0.000112 0.000103
21 I8xI8->I32C72 15 I8 I8 I32 C72 32.00000 32x32x32 0.000191 0.000180
22 I8xI8->I32C72 15 I8 I8 I32 C72 64.00000 64x64x64 0.000187 0.000178
23 I8xI8->I32C72 15 I8 I8 I32 C72 108.51728 64x128x92 0.000249 0.000203
Smaller sets¶
# Same tables restricted to a handful of configurations.
# NOTE(review): `dfi` is not defined in the visible part of this file --
# presumably built in an earlier section; confirm.
if df.shape[0] > 0:
    subset = {1, 3, 4, 5, 7}
    dfis = dfi[dfi.test.isin(subset)]
    # One pivot per measured duration: rows are problem sizes,
    # columns are configuration names.
    for column in ("t-gemm_sync", "t-total"):
        print()
        print(column)
        pivi = dfis.pivot_table(index=["~dim", "mnk"], columns="name", values=column)
        print(pivi)
t-gemm_sync
name BF16xBF16->BF16C68 F16xF16->F16C64 F32xF32->F32C77
~dim mnk
16.00000 16x16x16 0.000106 0.001020 0.000084
32.00000 32x32x32 0.000228 0.003417 0.000139
64.00000 64x64x64 0.000286 0.006414 0.000149
108.51728 64x128x92 0.000385 0.009356 0.000183
t-total
name BF16xBF16->BF16C68 F16xF16->F16C64 F32xF32->F32C77
~dim mnk
16.00000 16x16x16 0.000115 0.001053 0.000093
32.00000 32x32x32 0.000237 0.003493 0.000154
64.00000 64x64x64 0.000296 0.006469 0.000159
108.51728 64x128x92 0.000401 0.009382 0.000194
Plots¶
# Plot t-gemm_sync against the problem size: all configurations on the left,
# only the configurations listed in `subset` on the right.
if df.shape[0] > 0:
    pivot_spec = dict(index=["~dim", "mnk"], columns="name", values="t-gemm_sync")
    log_plot = dict(title="Gemm performance\nlower is better", logx=True, logy=True)

    piv = df.pivot_table(**pivot_spec)
    piv.plot(title="MatMul performances")

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    piv.plot(ax=ax[0], **log_plot)

    piv = df[df.test.isin(subset)].pivot_table(**pivot_spec)
    if piv.shape[0] > 0:
        # Nothing to draw on the right when none of the subset was run.
        piv.plot(ax=ax[1], **log_plot)

    fig.tight_layout()
    fig.savefig("plot_bench_gemm_f8.png")
Total running time of the script: (0 minutes 6.638 seconds)