Measuring performance of TfIdfVectorizer#
The benchmark measures the performance of a TfIdfVectorizer along two parameters: the vocabulary size and the batch size. It measures the benefit of using a sparse implementation through the computation time and the memory peak.
A simple model#
We start with a model containing only one TfIdfVectorizer node. It only uses unigrams. The model processes only sequences of 10 integers. The sparsity of the results is then 10 divided by the vocabulary size.
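In other words, with 10 distinct tokens per row the fraction of non-zero values in the dense output is 10 / n_words. A quick back-of-the-envelope check (the vocabulary sizes below are only illustrative):

# density of the dense TF output for rows of 10 distinct tokens
for n_words in (100, 1000, 10000):
    print(f"n_words={n_words:6d} -> density={10 / n_words:.2%}")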
import gc
import time
import itertools
from typing import Tuple
import numpy as np
import pandas
from onnx import ModelProto
from onnx.helper import make_attribute
from tqdm import tqdm
from onnxruntime import InferenceSession, SessionOptions
from onnx_extended.ext_test_case import measure_time, unit_test_going
from onnx_extended.memory_peak import start_spying_on
from onnx_extended.reference import CReferenceEvaluator
from onnx_extended.ortops.optim.cpu import get_ort_ext_libs
from onnx_extended.plotting.benchmark import vhistograms
def make_onnx(n_words: int) -> ModelProto:
from skl2onnx.common.data_types import Int64TensorType, FloatTensorType
from skl2onnx.algebra.onnx_ops import OnnxTfIdfVectorizer
# from onnx_array_api.light_api import start
# onx = (
# start(opset=19, opsets={"ai.onnx.ml": 3})
# .vin("X", elem_type=TensorProto.INT64)
# .ai.onnx.TfIdfVectorizer(
# ...
# )
# .rename(Y)
# .vout(elem_type=TensorProto.FLOAT)
# .to_onnx()
# )
onx = OnnxTfIdfVectorizer(
"X",
mode="TF",
min_gram_length=1,
max_gram_length=1,
max_skip_count=0,
ngram_counts=[0],
ngram_indexes=np.arange(n_words).tolist(),
pool_int64s=np.arange(n_words).tolist(),
output_names=["Y"],
).to_onnx(inputs=[("X", Int64TensorType())], outputs=[("Y", FloatTensorType())])
return onx
onx = make_onnx(7)
ref = CReferenceEvaluator(onx)
got = ref.run(None, {"X": np.array([[0, 1], [2, 3]], dtype=np.int64)})
print(got)
[array([[1., 1., 0., 0., 0., 0., 0.],
[0., 0., 1., 1., 0., 0., 0.]], dtype=float32)]
It works as expected. Let's now compare the execution with onnxruntime for different batch sizes and vocabulary sizes.
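Before the benchmark, a quick sanity check, not part of the original script, that onnxruntime produces the same dense output as the reference evaluator on the toy input:

# the reference evaluator and onnxruntime should agree on the dense output
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
got_ort = sess.run(None, {"X": np.array([[0, 1], [2, 3]], dtype=np.int64)})
assert np.allclose(got[0], got_ort[0])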
Benchmark#
def make_sessions(
onx: ModelProto,
) -> Tuple[InferenceSession, InferenceSession, InferenceSession]:
# first: onnxruntime
ref = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
# second: custom kernel equivalent to the onnxruntime implementation
for node in onx.graph.node:
if node.op_type == "TfIdfVectorizer":
node.domain = "onnx_extented.ortops.optim.cpu"
# new_add = make_attribute("sparse", 1)
# node.attribute.append(new_add)
d = onx.opset_import.add()
d.domain = "onnx_extented.ortops.optim.cpu"
d.version = 1
r = get_ort_ext_libs()
opts = SessionOptions()
opts.register_custom_ops_library(r[0])
cus = InferenceSession(
onx.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)
# third: with sparse
for node in onx.graph.node:
if node.op_type == "TfIdfVectorizer":
new_add = make_attribute("sparse", 1)
node.attribute.append(new_add)
cussp = InferenceSession(
onx.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)
return ref, cus, cussp
if unit_test_going():
vocabulary_sizes = [10, 20]
batch_sizes = [5, 10]
else:
vocabulary_sizes = [100, 1000, 5000, 10000]
batch_sizes = [1, 10, 500, 1000, 2000]
confs = list(itertools.product(vocabulary_sizes, batch_sizes))
data = []
for voc_size, batch_size in tqdm(confs):
onx = make_onnx(voc_size)
ref, cus, sparse = make_sessions(onx)
gc.collect()
feeds = dict(
X=(np.arange(batch_size * 10) % voc_size)
.reshape((batch_size, -1))
.astype(np.int64)
)
# sparse
p = start_spying_on(delay=0.0001)
sparse.run(None, feeds)
obs = measure_time(
lambda sparse=sparse, feeds=feeds: sparse.run(None, feeds), max_time=1
)
mem = p.stop()
obs["peak"] = mem["cpu"].max_peak - mem["cpu"].begin
obs["name"] = "sparse"
obs.update(dict(voc_size=voc_size, batch_size=batch_size))
data.append(obs)
time.sleep(0.1)
# reference
p = start_spying_on(delay=0.0001)
ref.run(None, feeds)
obs = measure_time(lambda ref=ref, feeds=feeds: ref.run(None, feeds), max_time=1)
mem = p.stop()
obs["peak"] = mem["cpu"].max_peak - mem["cpu"].begin
obs["name"] = "ref"
obs.update(dict(voc_size=voc_size, batch_size=batch_size))
data.append(obs)
time.sleep(0.1)
# custom
p = start_spying_on(delay=0.0001)
cus.run(None, feeds)
obs = measure_time(lambda cus=cus, feeds=feeds: cus.run(None, feeds), max_time=1)
mem = p.stop()
obs["peak"] = mem["cpu"].max_peak - mem["cpu"].begin
obs["name"] = "custom"
obs.update(dict(voc_size=voc_size, batch_size=batch_size))
data.append(obs)
time.sleep(0.1)
del sparse
del cus
del ref
del feeds
df = pandas.DataFrame(data)
df["time"] = df["average"]
df.to_csv("plot_op_tfidfvectorizer_sparse.csv", index=False)
print(df.head())
100%|██████████| 20/20 [01:28<00:00, 4.43s/it]
average deviation min_exec max_exec repeat number ttime context_size warmup_time peak name voc_size batch_size time
0 0.000012 5.508292e-07 0.000012 0.000078 1 97535.0 1.203585 64 0.000222 12288 sparse 100 1 0.000012
1 0.000010 4.909999e-07 0.000009 0.000061 1 101664.0 1.037275 64 0.000122 0 ref 100 1 0.000010
2 0.000010 8.988595e-07 0.000009 0.000101 1 120378.0 1.207956 64 0.000148 0 custom 100 1 0.000010
3 0.000025 3.640523e-06 0.000018 0.000104 1 53436.0 1.340555 64 0.000114 12288 sparse 100 10 0.000025
4 0.000021 7.798332e-06 0.000014 0.000330 1 60039.0 1.236294 64 0.000567 0 ref 100 10 0.000021
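Only a few columns matter for the analysis. A small sketch narrowing the dataframe (the column names come from the output above):

cols = ["voc_size", "batch_size", "name", "average", "peak"]
print(df[cols].head())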
Processing time#
piv = pandas.pivot_table(
df, index=["voc_size", "name"], columns="batch_size", values="average"
)
print(piv)
batch_size 1 10 500 1000 2000
voc_size name
100 custom 0.000010 0.000022 0.000074 0.000122 0.000240
ref 0.000010 0.000021 0.000114 0.000197 0.000361
sparse 0.000012 0.000025 0.000308 0.000435 0.000871
1000 custom 0.000010 0.000025 0.000233 0.000445 0.001150
ref 0.000011 0.000052 0.000511 0.000942 0.002084
sparse 0.000011 0.000033 0.000257 0.000472 0.000864
5000 custom 0.000013 0.000040 0.001422 0.002965 0.011873
ref 0.000021 0.000070 0.002404 0.004740 0.014551
sparse 0.000014 0.000022 0.000227 0.000539 0.000894
10000 custom 0.000013 0.000056 0.002876 0.009929 0.022311
ref 0.000022 0.000106 0.004527 0.014399 0.028333
sparse 0.000012 0.000021 0.000249 0.000508 0.000850
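The ratio between the dense onnxruntime implementation and the sparse kernel makes the trend easier to read. A possible summary, reusing the pivot table above:

# speed-up of the sparse kernel: values above 1 mean sparse is faster
ratio = piv.xs("ref", level="name") / piv.xs("sparse", level="name")
print(ratio)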
Memory peak#
The memory peak is always difficult to estimate. A second process is started to sample the physical memory every millisecond during the execution. The figure is the difference between this peak and the memory measured when the measurement began.
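The measurement pattern used in the benchmark loop, isolated into a minimal sketch (the allocation only exists to create a visible peak):

p = start_spying_on(delay=0.0001)  # sample the process memory every 0.1 ms
x = np.zeros((1 << 20,), dtype=np.float64)  # allocate ~8 MiB to create a peak
mem = p.stop()
print((mem["cpu"].max_peak - mem["cpu"].begin) / 2**20, "MiB")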
piv = pandas.pivot_table(
df, index=["voc_size", "name"], columns="batch_size", values="peak"
)
print(piv / 2**20)
batch_size 1 10 500 1000 2000
voc_size name
100 custom 0.000000 0.000000 0.000000 0.000000 0.089844
ref 0.000000 0.000000 0.000000 0.000000 0.000000
sparse 0.011719 0.011719 0.000000 0.000000 0.046875
1000 custom 0.000000 0.003906 0.000000 0.000000 5.246094
ref 0.000000 0.000000 0.003906 0.003906 0.000000
sparse 0.000000 0.000000 0.000000 0.000000 0.082031
5000 custom 0.000000 0.000000 9.539062 20.003906 79.871094
ref 0.000000 0.000000 4.906250 39.054688 77.992188
sparse 0.000000 0.000000 0.000000 0.000000 0.000000
10000 custom 0.000000 0.000000 20.003906 77.996094 154.175781
ref 0.000000 0.000000 39.085938 77.992188 154.175781
sparse 0.000000 0.000000 0.000000 0.000000 0.000000
Graphs#
ax = vhistograms(df)
fig = ax[0, 0].get_figure()
fig.savefig("plot_op_tfidfvectorizer_sparse.png")
Take away#
The sparse implementation works better when both the sparsity and the batch size are big enough.
Total running time of the script: (1 minutes 35.836 seconds)