Note
Go to the end to download the full example code
numpy.digitize as a tree#
Function numpy.digitize()
transforms a real variable
into a discrete one by returning the buckets the variable
falls into. This bucket can be efficiently retrieved by doing a
binary search over the bins. That’s equivalent to a decision tree.
Function digitize2tree builds this tree from the bins.
Simple example#
import numpy
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from pandas import DataFrame, pivot, pivot_table
from skl2onnx import to_onnx
from sklearn.tree import export_text
from tqdm import tqdm
from mlinsights.ext_test_case import measure_time
from mlinsights.mltree import digitize2tree
# Reference answer: numpy assigns each value to a bucket (right=True).
bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
x = numpy.array([0.2, 6.4, 3.0, 1.6])
expected = numpy.digitize(x, bins, right=True)

# Same bucketing expressed as a tree; predict is fed a single-column 2D array.
tree = digitize2tree(bins, right=True)
column = x.reshape((-1, 1))
pred = tree.predict(column)

# Both results are printed side by side for comparison.
print(expected, pred)
[2023-11-05 12:49:16,083] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[1 4 3 2] [1. 4. 3. 2.]
The tree looks like the following.
# Render the tree's decision rules as text, naming its only feature "x".
rules = export_text(tree, feature_names=["x"])
print(rules)
|--- x <= 2.50
| |--- x <= 1.00
| | |--- x <= 0.00
| | | |--- value: [0.00]
| | |--- x > 0.00
| | | |--- value: [1.00]
| |--- x > 1.00
| | |--- value: [2.00]
|--- x > 2.50
| |--- x <= 4.00
| | |--- x <= 2.50
| | | |--- value: [2.00]
| | |--- x > 2.50
| | | |--- value: [3.00]
| |--- x > 4.00
| | |--- x <= 7.00
| | | |--- x <= 4.00
| | | | |--- value: [3.00]
| | | |--- x > 4.00
| | | | |--- value: [4.00]
| | |--- x > 7.00
| | | |--- value: [5.00]
Benchmark#
Let’s measure the processing time. numpy should be much faster than scikit-learn, as scikit-learn adds many verifications. However, the benchmark also converts the tree into ONNX and measures the processing time with onnxruntime.
def _tag(timing, name, n_bins, shape):
    """Label a timing record with its benchmark coordinates and return it."""
    timing["name"] = name
    timing["n_bins"] = n_bins
    timing["shape"] = shape
    return timing


obs = []
for shape in tqdm([1, 10, 100, 1000, 10000, 100000]):
    x = numpy.random.random(shape).astype(numpy.float32)
    # The tree and the ONNX session both receive a single-column 2D view.
    x_col = x.reshape((-1, 1))
    # Inputs below 1000 elements get 100 repetitions, larger ones get 10.
    repeat = number = 100 if shape < 1000 else 10

    for n_bins in [1, 10, 100]:
        bins = (numpy.arange(n_bins) / n_bins).astype(numpy.float32)

        # Baseline: numpy.digitize on the flat array.
        numpy_time = measure_time(
            "numpy.digitize(x, bins, right=True)",
            context={"numpy": numpy, "x": x, "bins": bins},
            div_by_number=True,
            repeat=repeat,
            number=number,
        )
        obs.append(_tag(numpy_time, "numpy", n_bins, shape))

        # Same bucketing through the tree built by digitize2tree.
        tree = digitize2tree(bins, right=True)
        sklearn_time = measure_time(
            "tree.predict(x)",
            context={"numpy": numpy, "x": x_col, "tree": tree},
            div_by_number=True,
            repeat=repeat,
            number=number,
        )
        obs.append(_tag(sklearn_time, "sklearn", n_bins, shape))

        # The tree converted to ONNX and executed by onnxruntime on CPU.
        onx = to_onnx(tree, x_col, target_opset=15)
        sess = InferenceSession(
            onx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_time = measure_time(
            "sess.run(None, {'X': x})",
            context={"numpy": numpy, "x": x_col, "sess": sess},
            div_by_number=True,
            repeat=repeat,
            number=number,
        )
        obs.append(_tag(ort_time, "ort", n_bins, shape))

# Summary table: one row per input size, one column per (n_bins, name) pair.
df = DataFrame(obs)
piv = pivot_table(
    data=df,
    index="shape",
    columns=["n_bins", "name"],
    values=["average"],
)
print(piv)
0%| | 0/6 [00:00<?, ?it/s]
17%|█▋ | 1/6 [00:03<00:15, 3.04s/it]
33%|███▎ | 2/6 [00:07<00:14, 3.61s/it]
50%|█████ | 3/6 [00:10<00:11, 3.74s/it]
67%|██████▋ | 4/6 [00:11<00:04, 2.31s/it]
83%|████████▎ | 5/6 [00:11<00:01, 1.60s/it]
100%|██████████| 6/6 [00:13<00:00, 1.71s/it]
100%|██████████| 6/6 [00:13<00:00, 2.22s/it]
average ...
n_bins 1 ... 100
name numpy ort sklearn ... numpy ort sklearn
shape ...
1 0.000003 0.000009 0.000082 ... 0.000003 0.000012 0.000094
10 0.000004 0.000016 0.000130 ... 0.000004 0.000012 0.000112
100 0.000004 0.000013 0.000095 ... 0.000006 0.000017 0.000129
1000 0.000011 0.000029 0.000122 ... 0.000095 0.000078 0.000292
10000 0.000135 0.000091 0.000263 ... 0.000643 0.000147 0.000816
100000 0.000840 0.000181 0.000529 ... 0.004545 0.001339 0.005261
[6 rows x 9 columns]
Plotting#
# One subplot per distinct number of bins found in the benchmark results.
n_bins = sorted(set(df.n_bins))
fig, ax = plt.subplots(1, len(n_bins), figsize=(14, 4))
for axis, nb in zip(ax, n_bins):
    subset = df[df.n_bins == nb]
    # Rows are input sizes, columns are the implementations being compared.
    piv = pivot(data=subset, index="shape", columns="name", values="average")
    piv.plot(
        title="Benchmark digitize / onnxruntime\nn_bins=%d" % nb,
        logx=True,
        logy=True,
        ax=axis,
    )
Total running time of the script: (0 minutes 21.778 seconds)