Measuring performance of TfIdfVectorizer

The benchmark measures the performance of a TfIdfVectorizer along two parameters: the vocabulary size and the batch size. It measures the benefit of using a sparse implementation through the computation time and the memory peak.

A simple model

We start with a model made of a single TfIdfVectorizer node. It only contains unigrams. The model only processes sequences of 10 integers. The sparsity of the results is then 10 divided by the size of the vocabulary.

import gc
import time
import itertools
from typing import Tuple
import numpy as np
import pandas
from onnx import ModelProto
from onnx.helper import make_attribute
from tqdm import tqdm
from onnxruntime import InferenceSession, SessionOptions
from onnx_extended.ext_test_case import measure_time, unit_test_going
from onnx_extended.memory_peak import start_spying_on
from onnx_extended.reference import CReferenceEvaluator
from onnx_extended.ortops.optim.cpu import get_ort_ext_libs
from onnx_extended.plotting.benchmark import vhistograms


def make_onnx(n_words: int) -> ModelProto:
    """Build a one-node ONNX model holding a unigram TfIdfVectorizer.

    The vocabulary is the sequence of integers ``0 .. n_words - 1`` and
    each word is mapped to the output column carrying the same index.

    :param n_words: size of the vocabulary
    :return: model with an int64 input ``X`` and a float output ``Y``
    """
    from skl2onnx.algebra.onnx_ops import OnnxTfIdfVectorizer
    from skl2onnx.common.data_types import FloatTensorType, Int64TensorType

    # NOTE(review): the same graph could presumably be written with
    # onnx_array_api.light_api; skl2onnx is what is actually used here.

    # The pool of unigrams and the output column assigned to each of them
    # are the same list: 0, 1, ..., n_words - 1.
    vocabulary = np.arange(n_words).tolist()
    columns = np.arange(n_words).tolist()

    node = OnnxTfIdfVectorizer(
        "X",
        mode="TF",
        min_gram_length=1,
        max_gram_length=1,
        max_skip_count=0,
        ngram_counts=[0],
        ngram_indexes=columns,
        pool_int64s=vocabulary,
        output_names=["Y"],
    )
    return node.to_onnx(
        inputs=[("X", Int64TensorType())],
        outputs=[("Y", FloatTensorType())],
    )


# Sanity check on a 7-word vocabulary: run the model through
# CReferenceEvaluator on two sequences of two word ids and print the result.
onx = make_onnx(7)
ref = CReferenceEvaluator(onx)
# The input is a 2x2 int64 array; passing None as first argument
# presumably requests every output of the model.
got = ref.run(None, {"X": np.array([[0, 1], [2, 3]], dtype=np.int64)})
print(got)
[array([[1., 1., 0., 0., 0., 0., 0.],
       [0., 0., 1., 1., 0., 0., 0.]], dtype=float32)]

It works as expected. Let’s now compare the execution with onnxruntime for different batch sizes and vocabulary sizes.

Benchmark

def make_sessions(
    onx: ModelProto,
) -> Tuple[InferenceSession, InferenceSession, InferenceSession]:
    """Create the three inference sessions compared by the benchmark.

    The sessions are, in order:

    1. the model as given, run by onnxruntime on the CPU provider,
    2. the same model where every ``TfIdfVectorizer`` node is moved to the
       domain ``onnx_extended.ortops.optim.cpu`` (custom kernel),
    3. the custom kernel with the extra attribute ``sparse=1``.

    The input model is not modified: the domain, opset and attribute
    changes are applied to a private copy. Editing *onx* in place would
    leave the caller with a rewritten model and would append a duplicate
    opset entry and a duplicate ``sparse`` attribute if the function were
    called twice on the same model.

    :param onx: model containing the ``TfIdfVectorizer`` node(s)
    :return: tuple ``(reference, custom, custom_sparse)`` of sessions
    """
    custom_domain = "onnx_extended.ortops.optim.cpu"
    providers = ["CPUExecutionProvider"]

    # first: onnxruntime
    ref = InferenceSession(onx.SerializeToString(), providers=providers)

    # Every modification below goes to a copy so that the caller's model
    # stays as it was received.
    model = ModelProto()
    model.CopyFrom(onx)

    # second: custom kernel equivalent to the onnxruntime implementation
    for node in model.graph.node:
        if node.op_type == "TfIdfVectorizer":
            node.domain = custom_domain

    # The custom domain must be declared among the opsets of the model.
    opset = model.opset_import.add()
    opset.domain = custom_domain
    opset.version = 1

    # The first library returned by get_ort_ext_libs is registered so that
    # onnxruntime can resolve the nodes of the custom domain.
    opts = SessionOptions()
    opts.register_custom_ops_library(get_ort_ext_libs()[0])
    cus = InferenceSession(model.SerializeToString(), opts, providers=providers)

    # third: with sparse
    for node in model.graph.node:
        if node.op_type == "TfIdfVectorizer":
            node.attribute.append(make_attribute("sparse", 1))
    cussp = InferenceSession(model.SerializeToString(), opts, providers=providers)

    return ref, cus, cussp


# A tiny grid keeps the run short while unit tests are going;
# otherwise the full grid is benchmarked.
vocabulary_sizes, batch_sizes = (
    ([10, 20], [5, 10])
    if unit_test_going()
    else ([100, 1000, 5000, 10000], [1, 10, 500, 1000, 2000])
)
# Every (vocabulary size, batch size) pair, vocabulary size varying slowest.
confs = list(itertools.product(vocabulary_sizes, batch_sizes))

# One record per (configuration, implementation): timing statistics returned
# by measure_time plus the memory peak observed while they were collected.
data = []
for voc_size, batch_size in tqdm(confs):
    onx = make_onnx(voc_size)
    ref, cus, sparse = make_sessions(onx)
    gc.collect()

    # batch_size rows of word ids cycling through the vocabulary
    feeds = {
        "X": (np.arange(batch_size * 10) % voc_size)
        .reshape((batch_size, -1))
        .astype(np.int64)
    }

    # Implementations are measured in a fixed order: sparse, ref, custom.
    for label, session in (("sparse", sparse), ("ref", ref), ("custom", cus)):
        p = start_spying_on(delay=0.0001)
        # one run outside of the timed section
        session.run(None, feeds)
        obs = measure_time(
            lambda session=session, feeds=feeds: session.run(None, feeds),
            max_time=1,
        )
        mem = p.stop()
        # peak relative to the memory recorded when the spy started
        obs["peak"] = mem["cpu"].max_peak - mem["cpu"].begin
        obs["name"] = label
        obs.update({"voc_size": voc_size, "batch_size": batch_size})
        data.append(obs)
        time.sleep(0.1)

    # Drop every reference to the sessions and the inputs before the next
    # configuration is built.
    del label, session
    del sparse
    del cus
    del ref
    del feeds

df = pandas.DataFrame(data)
# "time" duplicates the "average" column under another name.
df["time"] = df["average"]
df.to_csv("plot_op_tfidfvectorizer_sparse.csv", index=False)
print(df.head())
  0%|          | 0/20 [00:00<?, ?it/s]
  5%|▌         | 1/20 [00:04<01:25,  4.51s/it]
 10%|█         | 2/20 [00:09<01:25,  4.73s/it]
 15%|█▌        | 3/20 [00:14<01:22,  4.82s/it]
 20%|██        | 4/20 [00:18<01:14,  4.67s/it]
 25%|██▌       | 5/20 [00:22<01:06,  4.46s/it]
 30%|███       | 6/20 [00:27<01:01,  4.39s/it]
 35%|███▌      | 7/20 [00:31<00:56,  4.38s/it]
 40%|████      | 8/20 [00:35<00:51,  4.32s/it]
 45%|████▌     | 9/20 [00:39<00:47,  4.31s/it]
 50%|█████     | 10/20 [00:44<00:43,  4.31s/it]
 55%|█████▌    | 11/20 [00:48<00:37,  4.19s/it]
 60%|██████    | 12/20 [00:52<00:33,  4.14s/it]
 65%|██████▌   | 13/20 [00:56<00:29,  4.16s/it]
 70%|███████   | 14/20 [01:00<00:25,  4.26s/it]
 75%|███████▌  | 15/20 [01:05<00:21,  4.31s/it]
 80%|████████  | 16/20 [01:09<00:16,  4.22s/it]
 85%|████████▌ | 17/20 [01:13<00:12,  4.16s/it]
 90%|█████████ | 18/20 [01:18<00:08,  4.41s/it]
 95%|█████████▌| 19/20 [01:22<00:04,  4.37s/it]
100%|██████████| 20/20 [01:27<00:00,  4.45s/it]
100%|██████████| 20/20 [01:27<00:00,  4.36s/it]
    average  deviation  min_exec  max_exec  repeat   number     ttime  context_size  warmup_time      peak    name  voc_size  batch_size      time
0  0.000023   0.000002  0.000021  0.000106       1  52701.0  1.191671            64     0.000237  18178048  sparse       100           1  0.000023
1  0.000028   0.000006  0.000021  0.000219       1  40952.0  1.147997            64     0.000290     57344     ref       100           1  0.000028
2  0.000024   0.000007  0.000015  0.000055       1  48615.0  1.158172            64     0.000161     45056  custom       100           1  0.000024
3  0.000030   0.000004  0.000025  0.000057       1  48813.0  1.481819            64     0.000150     45056  sparse       100          10  0.000030
4  0.000030   0.000006  0.000023  0.000407       1  38366.0  1.138974            64     0.000836     28672     ref       100          10  0.000030

Processing time

# Average execution time: one row per (vocabulary size, implementation),
# one column per batch size.
piv = df.pivot_table(
    index=["voc_size", "name"],
    columns="batch_size",
    values="average",
)
print(piv)
batch_size           1         10        500       1000      2000
voc_size name
100      custom  0.000024  0.000028  0.000076  0.000135  0.000267
         ref     0.000028  0.000030  0.000111  0.000185  0.000395
         sparse  0.000023  0.000030  0.000360  0.000479  0.000954
1000     custom  0.000015  0.000029  0.000274  0.000520  0.001750
         ref     0.000015  0.000037  0.000578  0.001072  0.002414
         sparse  0.000016  0.000028  0.000259  0.000545  0.000952
5000     custom  0.000015  0.000048  0.002101  0.003968  0.012351
         ref     0.000022  0.000075  0.002928  0.005921  0.015786
         sparse  0.000017  0.000027  0.000273  0.000545  0.001071
10000    custom  0.000017  0.000062  0.004017  0.012112  0.029190
         ref     0.000028  0.000145  0.006155  0.016129  0.034435
         sparse  0.000017  0.000030  0.000258  0.000549  0.000934

Memory peak

It is always difficult to estimate. A second process is started to measure the physical memory peak during the execution every ms. The reported figure is the difference between this peak and the memory usage when the measurement began.

# Memory peak: one row per (vocabulary size, implementation),
# one column per batch size.
piv = df.pivot_table(
    index=["voc_size", "name"],
    columns="batch_size",
    values="peak",
)
# Scaled by 2**20, i.e. shown in MiB assuming the peak is in bytes
# (TODO confirm the unit returned by the memory spy).
print(piv / 2**20)
batch_size            1         10         500        1000        2000
voc_size name
100      custom   0.042969  0.027344   0.042969   0.031250    1.230469
         ref      0.054688  0.027344   0.289062   0.453125    0.187500
         sparse  17.335938  0.042969   0.023438   0.015625    0.046875
1000     custom   0.000000  0.003906   1.910156   3.628906   10.000000
         ref      0.000000  0.000000   3.761719   3.613281   14.843750
         sparse   0.000000  0.000000   0.000000   0.000000    0.019531
5000     custom   0.000000  0.121094  10.710938  21.003906   78.156250
         ref      0.000000  0.000000  19.000000  38.585938   78.003906
         sparse   0.000000  0.000000   0.000000   0.066406    0.000000
10000    custom   0.000000  0.000000  20.003906  78.152344  154.179688
         ref      0.000000  0.000000  22.851562  78.148438  154.226562
         sparse   0.000000  0.000000   0.000000   0.000000    0.000000

Graphs

# Plot the benchmark results; vhistograms returns a container of axes
# indexed by [row, column] (presumably a 2D array of matplotlib axes).
ax = vhistograms(df)
# All axes are assumed to belong to a single figure, fetched from the first one.
fig = ax[0, 0].get_figure()
# Save the figure to a PNG file in the current directory.
fig.savefig("plot_op_tfidfvectorizer_sparse.png")
Compares Implementations of TfIdfVectorizer

Take away

The sparse implementation works better when both the sparsity and the batch size are big enough.

Total running time of the script: (1 minutes 46.085 seconds)

Gallery generated by Sphinx-Gallery