TreeEnsemble, dense, and sparse

This example benchmarks the sparse implementation of TreeEnsemble. The default set of parameters to optimize is very short so that the example runs quickly. Many more parameters can be tried:

python plot_op_tree_ensemble_sparse.py --scenario=LONG

To change the training parameters:

python plot_op_tree_ensemble_sparse.py \
    --n_trees=100 \
    --max_depth=10 \
    --n_features=50 \
    --sparsity=0.9 \
    --batch_size=100000

Another example with a full list of parameters:

python plot_op_tree_ensemble_sparse.py \
    --n_trees=100 --max_depth=10 --n_features=50 --batch_size=100000 \
    --sparsity=0.9 --tries=3 --scenario=CUSTOM \
    --parallel_tree=80,40 --parallel_tree_N=128,64 --parallel_N=50,25 \
    --batch_size_tree=1,2 --batch_size_rows=1,2 --use_node3=0

Another example:

python plot_op_tree_ensemble_sparse.py \
    --n_trees=100 --n_features=10 --batch_size=10000 --max_depth=8 -s SHORT
import logging
import os
import timeit
from typing import Optional, Tuple

import numpy
import onnx
from onnx import ModelProto, TensorProto
from onnx.helper import make_graph, make_model, make_tensor_value_info
from pandas import DataFrame, concat
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from skl2onnx import to_onnx
from onnxruntime import InferenceSession, SessionOptions
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

from onnx_extended.ortops.optim.cpu import get_ort_ext_libs
from onnx_extended.ortops.optim.optimize import (
    change_onnx_operator_domain,
    get_node_attribute,
    optimize_model,
)
from onnx_extended.tools.onnx_nodes import multiply_tree
from onnx_extended.validation.cpu._validation import dense_to_sparse_struct
from onnx_extended.plotting.benchmark import hhistograms
from onnx_extended.args import get_parsed_args
from onnx_extended.ext_test_case import unit_test_going

logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)

# Command line arguments. Most options are declared as ``(default, help)``;
# smaller defaults are used when ``unit_test_going()`` is true so that the
# example stays fast in the test suite.
# NOTE(review): options read later in this script but not declared here
# (scenario, number, repeat, warmup, sleep, tries) are presumably added by
# get_parsed_args itself — confirm in onnx_extended.args.
script_args = get_parsed_args(
    "plot_op_tree_ensemble_sparse",
    description=__doc__,
    scenarios={
        "SHORT": "short optimization (default)",
        "LONG": "test more options",
        "CUSTOM": "use values specified by the command line",
    },
    sparsity=(0.99, "input sparsity"),
    n_features=(2 if unit_test_going() else 500, "number of features to generate"),
    n_trees=(3 if unit_test_going() else 10, "number of trees to train"),
    max_depth=(2 if unit_test_going() else 10, "max_depth"),
    # Same default whether or not unit tests are running.
    batch_size=(1000, "batch size"),
    # Comma separated candidate values, only used by the CUSTOM scenario.
    parallel_tree=("80,160,40", "values to try for parallel_tree"),
    parallel_tree_N=("256,128,64", "values to try for parallel_tree_N"),
    parallel_N=("100,50,25", "values to try for parallel_N"),
    batch_size_tree=("2,4,8", "values to try for batch_size_tree"),
    batch_size_rows=("2,4,8", "values to try for batch_size_rows"),
    use_node3=("0,1", "values to try for use_node3"),
    expose="",
    n_jobs=("-1", "number of jobs to train the RandomForestRegressor"),
)

Training a model

def _make_sparse_regression(
    n_samples: int, n_features: int, sparsity: float
) -> Tuple[numpy.ndarray, numpy.ndarray]:
    """Generates a random regression problem and zeroes part of its features.

    :param n_samples: number of rows to generate
    :param n_features: number of columns of X
    :param sparsity: probability for each coefficient of X to be set to zero
    :return: X, y, both cast to float32
    """
    X, y = make_regression(n_samples, n_features=n_features, n_targets=1)
    # numpy.random.rand draws uniformly in [0, 1): each coefficient is
    # selected with probability *sparsity*.
    mask = numpy.random.rand(*X.shape) <= sparsity
    X[mask] = 0
    return X.astype(numpy.float32), y.astype(numpy.float32)


def train_model(
    batch_size: int,
    n_features: int,
    n_trees: int,
    max_depth: int,
    sparsity: float,
    n_jobs: Optional[int] = None,
) -> Tuple[str, numpy.ndarray, numpy.ndarray]:
    """Trains (or reuses) a tree ensemble and returns a batch to run it on.

    A RandomForestRegressor with a single tree of depth *max_depth* is
    trained, converted to ONNX, and its node is duplicated with
    ``multiply_tree`` to get *n_trees* trees, which is much faster than
    training all of them. The ONNX model is cached on disk; its file name
    depends on *n_features*, *n_trees*, *max_depth* and *sparsity* but not
    on *batch_size*.

    :param batch_size: number of rows of the returned batch
    :param n_features: number of features
    :param n_trees: number of trees in the saved ONNX model
    :param max_depth: maximum depth of the trained tree
    :param sparsity: probability for a feature value to be set to zero
    :param n_jobs: number of jobs used to train the forest, defaults to
        ``int(script_args.n_jobs)``
    :return: ONNX file name, Xb (float32, ``(batch_size, n_features)``),
        yb (float32, ``(batch_size,)``)
    :raises ValueError: if *batch_size* is not strictly positive

    .. note::
        When the ONNX file already exists, ``make_regression`` is called
        again without a random_state, so a new random problem is drawn:
        *yb* is then unrelated to what the cached model learned and only
        *Xb* is meaningful (it is used for benchmarking).
    """
    if batch_size <= 0:
        # X[-0:] would return the whole array and X[:-0] an empty one.
        raise ValueError(f"batch_size must be > 0 but is {batch_size}.")
    filename = (
        f"plot_op_tree_ensemble_sparse-f{n_features}-{n_trees}-"
        f"d{max_depth}-s{sparsity}.onnx"
    )
    if os.path.exists(filename):
        X, y = _make_sparse_regression(batch_size, n_features, sparsity)
    else:
        # Enough rows to train a tree of depth max_depth, plus the held-out
        # batch returned to the caller.
        X, y = _make_sparse_regression(
            batch_size + max(batch_size, 2 ** (max_depth + 1)), n_features, sparsity
        )

        print(f"Training to get {filename!r} with X.shape={X.shape}")
        if n_jobs is None:
            n_jobs = int(script_args.n_jobs)
        # To be faster, we train only 1 tree.
        model = RandomForestRegressor(1, max_depth=max_depth, verbose=2, n_jobs=n_jobs)
        model.fit(X[:-batch_size], y[:-batch_size])
        onx = to_onnx(model, X[:1])

        # And we multiply the trees.
        node = multiply_tree(onx.graph.node[0], n_trees)
        onx = make_model(
            make_graph([node], onx.graph.name, onx.graph.input, onx.graph.output),
            domain=onx.domain,
            opset_imports=onx.opset_import,
            # make_model would otherwise use the IR version of the installed
            # onnx package, which may be newer than what onnxruntime supports.
            ir_version=onx.ir_version,
        )

        with open(filename, "wb") as f:
            f.write(onx.SerializeToString())
    Xb, yb = X[-batch_size:].copy(), y[-batch_size:].copy()
    return filename, Xb, yb


def measure_sparsity(x):
    """Returns the fraction of coefficients of *x* equal to zero.

    :param x: array (or array-like) of any shape
    :return: ratio in [0, 1]; 0.0 for an empty array
    """
    x = numpy.asarray(x)
    if x.size == 0:
        # Avoids a ZeroDivisionError, an empty array contains no zero.
        return 0.0
    # count_nonzero on the boolean mask avoids flattening (copying) x and
    # building an int64 temporary.
    return float(numpy.count_nonzero(x == 0)) / float(x.size)


# Experiment settings, taken from the command line.
batch_size, n_features, n_trees, max_depth, sparsity = (
    script_args.batch_size,
    script_args.n_features,
    script_args.n_trees,
    script_args.max_depth,
    script_args.sparsity,
)

for _name, _value in (
    ("batch_size", batch_size),
    ("n_features", n_features),
    ("n_trees", n_trees),
    ("max_depth", max_depth),
    ("sparsity", sparsity),
):
    print(f"{_name}={_value}")
del _name, _value
batch_size=1000
n_features=500
n_trees=10
max_depth=10
sparsity=0.99

Train the model, or reuse the ONNX file if it already exists.

filename, Xb, yb = train_model(
    batch_size=batch_size,
    n_features=n_features,
    n_trees=n_trees,
    max_depth=max_depth,
    sparsity=sparsity,
)

# Shapes of the held-out batch and the actual fraction of zeros it contains.
print(f"Xb.shape={Xb.shape}")
print(f"yb.shape={yb.shape}")
print(f"measured sparsity={measure_sparsity(Xb)}")
Training to get 'plot_op_tree_ensemble_sparse-f500-10-d10-s0.99.onnx' with X.shape=(3048, 500)
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
building tree 1 of 1
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:    0.1s finished
Xb.shape=(1000, 500)
yb.shape=(1000,)
measured sparsity=0.989996

Rewrite the ONNX file to use a different kernel

The custom kernel is mapped to a custom operator with the same name and the same attributes, but in domain "onnx_extented.ortops.optim.cpu" (note the spelling, it must match the domain registered by the custom library). We call a function to do that replacement. First, the current model.

# Load the model trained (or cached) above and display it.
onx = onnx.load(filename)
print(onnx_simple_text_plot(onx))
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=19
input: name='X' type=dtype('float32') shape=['', 500]
TreeEnsembleRegressor(X, n_targets=1, nodes_falsenodeids=630:[58,7,6...62,0,0], nodes_featureids=630:[386,263,69...290,264,27], nodes_hitrates=630:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=630:[0,0,0...0,0,0], nodes_modes=630:[b'BRANCH_LEQ',b'BRANCH_LEQ'...b'LEAF',b'LEAF'], nodes_nodeids=630:[0,1,2...60,61,62], nodes_treeids=630:[0,0,0...9,9,9], nodes_truenodeids=630:[1,2,3...61,0,0], nodes_values=630:[1.0825226306915283,-1.293250322341919...-0.006145985797047615,0.1449897587299347], post_transform=b'NONE', target_ids=320:[0,0,0...0,0,0], target_nodeids=320:[4,5,6...59,61,62], target_treeids=320:[0,0,0...9,9,9], target_weights=320:[-429.5425109863281,-346.9263610839844...412.6749267578125,341.1100158691406]) -> variable
output: name='variable' type=dtype('float32') shape=['', 1]

And then the modified model.

def _missing_tracks_true(mode_names, thresholds):
    """Returns, for every node, 1 if its branch test is true for a value of 0.

    With a sparse input, a missing value means 0, so the missing values must
    follow the branch a 0 would take.

    :param mode_names: node modes such as ``"BRANCH_LEQ"`` or ``"LEAF"``
    :param thresholds: float32 array of ``nodes_values``, one per node
    :return: int64 array usable as ``nodes_missing_value_tracks_true``
    """
    # BRANCH_LEQ rule (0 <= threshold). It is also applied to leaves, where
    # the flag is presumably ignored since no comparison happens there.
    tracks_true = (thresholds >= 0).astype(numpy.int64)
    at_zero = {
        "BRANCH_LT": thresholds > 0,  # 0 < t
        "BRANCH_GTE": thresholds <= 0,  # 0 >= t
        "BRANCH_GT": thresholds < 0,  # 0 > t
        "BRANCH_EQ": thresholds == 0,  # 0 == t
        "BRANCH_NEQ": thresholds != 0,  # 0 != t
    }
    for i, mode in enumerate(mode_names):
        if mode in at_zero:
            tracks_true[i] = at_zero[mode][i]
    return tracks_true


def transform_model(model, use_sparse=False, **kwargs):
    """Returns a copy of *model* where the TreeEnsembleRegressor node from
    domain ``ai.onnx.ml`` is replaced by the custom kernel of domain
    ``onnx_extented.ortops.optim.cpu``.

    The attributes of the first node of the graph are read to build the new
    ones; *model* itself is not modified.

    :param model: ModelProto whose first node is a TreeEnsembleRegressor
    :param use_sparse: use TreeEnsembleRegressorSparse (unless *new_op_type*
        is given) and replace the graph input by a 1D float tensor
    :param kwargs: additional attributes forwarded to
        ``change_onnx_operator_domain`` (parallelization parameters, ...);
        ``new_op_type`` and ``nodes_missing_value_tracks_true`` given here
        take precedence over the values computed for the sparse case
    :return: the modified ModelProto
    """
    onx = ModelProto()
    onx.ParseFromString(model.SerializeToString())
    node = onx.graph.node[0]
    mode_names = [
        s.decode("ascii") for s in get_node_attribute(node, "nodes_modes").strings
    ]
    # The custom kernel takes the modes as a single string without the
    # "BRANCH_" prefix, e.g. "LEQ,LEQ,LEAF,...".
    modes = ",".join(mode_names).replace("BRANCH_", "")
    if use_sparse:
        kwargs.setdefault("new_op_type", "TreeEnsembleRegressorSparse")
        if "nodes_missing_value_tracks_true" not in kwargs:
            # with sparse tensor, missing value means 0
            thresholds = numpy.array(
                get_node_attribute(node, "nodes_values").floats, dtype=numpy.float32
            )
            kwargs["nodes_missing_value_tracks_true"] = _missing_tracks_true(
                mode_names, thresholds
            )
    new_onx = change_onnx_operator_domain(
        onx,
        op_type="TreeEnsembleRegressor",
        op_domain="ai.onnx.ml",
        new_op_domain="onnx_extented.ortops.optim.cpu",
        nodes_modes=modes,
        **kwargs,
    )
    if use_sparse:
        # The sparse kernel consumes a 1D float tensor (presumably the
        # structure built by dense_to_sparse_struct) instead of a matrix.
        # The original input name is kept so the node stays connected.
        input_name = new_onx.graph.input[0].name
        del new_onx.graph.input[:]
        new_onx.graph.input.append(
            make_tensor_value_info(input_name, TensorProto.FLOAT, (None,))
        )
    return new_onx


print("Tranform model to add a custom node.")
onx_modified = transform_model(onx)
# Keep a copy of the rewritten model on disk.
modified_name = filename + "modified.onnx"
print(f"Save into {modified_name!r}.")
with open(modified_name, "wb") as stream:
    stream.write(onx_modified.SerializeToString())
print("done.")
print(onnx_simple_text_plot(onx_modified))
Tranform model to add a custom node.
Save into 'plot_op_tree_ensemble_sparse-f500-10-d10-s0.99.onnxmodified.onnx'.
done.
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=19
opset: domain='onnx_extented.ortops.optim.cpu' version=1
input: name='X' type=dtype('float32') shape=['', 500]
TreeEnsembleRegressor[onnx_extented.ortops.optim.cpu](X, nodes_modes=b'LEQ,LEQ,LEQ,LEQ,LEAF,LEAF,LEAF,LEQ,LEQ...LEAF,LEAF', n_targets=1, nodes_falsenodeids=630:[58,7,6...62,0,0], nodes_featureids=630:[386,263,69...290,264,27], nodes_hitrates=630:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=630:[0,0,0...0,0,0], nodes_nodeids=630:[0,1,2...60,61,62], nodes_treeids=630:[0,0,0...9,9,9], nodes_truenodeids=630:[1,2,3...61,0,0], nodes_values=630:[1.0825226306915283,-1.293250322341919...-0.006145985797047615,0.1449897587299347], post_transform=b'NONE', target_ids=320:[0,0,0...0,0,0], target_nodeids=320:[4,5,6...59,61,62], target_treeids=320:[0,0,0...9,9,9], target_weights=320:[-429.5425109863281,-346.9263610839844...412.6749267578125,341.1100158691406]) -> variable
output: name='variable' type=dtype('float32') shape=['', 1]

Same with sparse.

print("Same transformation but with sparse.")
onx_modified_sparse = transform_model(onx, use_sparse=True)
# Keep a copy of the rewritten sparse model on disk.
sparse_name = filename + "modified.sparse.onnx"
print(f"Save into {sparse_name!r}.")
with open(sparse_name, "wb") as stream:
    stream.write(onx_modified_sparse.SerializeToString())
print("done.")
print(onnx_simple_text_plot(onx_modified_sparse))
Same transformation but with sparse.
Save into 'plot_op_tree_ensemble_sparse-f500-10-d10-s0.99.onnxmodified.sparse.onnx'.
done.
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=19
opset: domain='onnx_extented.ortops.optim.cpu' version=1
input: name='X' type=dtype('float32') shape=['']
TreeEnsembleRegressorSparse[onnx_extented.ortops.optim.cpu](X, nodes_missing_value_tracks_true=630:[1,0,1...0,0,1], nodes_modes=b'LEQ,LEQ,LEQ,LEQ,LEAF,LEAF,LEAF,LEQ,LEQ...LEAF,LEAF', n_targets=1, nodes_falsenodeids=630:[58,7,6...62,0,0], nodes_featureids=630:[386,263,69...290,264,27], nodes_hitrates=630:[1.0,1.0...1.0,1.0], nodes_nodeids=630:[0,1,2...60,61,62], nodes_treeids=630:[0,0,0...9,9,9], nodes_truenodeids=630:[1,2,3...61,0,0], nodes_values=630:[1.0825226306915283,-1.293250322341919...-0.006145985797047615,0.1449897587299347], post_transform=b'NONE', target_ids=320:[0,0,0...0,0,0], target_nodeids=320:[4,5,6...59,61,62], target_treeids=320:[0,0,0...9,9,9], target_weights=320:[-429.5425109863281,-346.9263610839844...412.6749267578125,341.1100158691406]) -> variable
output: name='variable' type=dtype('float32') shape=['', 1]

Comparing onnxruntime and the custom kernel

print(f"Loading {filename!r}")
sess_ort = InferenceSession(filename, providers=["CPUExecutionProvider"])

r = get_ort_ext_libs()
print(f"Creating SessionOptions with {r!r}")
if r is None:
    # Without the custom library the domain onnx_extented.ortops.optim.cpu
    # cannot be resolved and loading the modified models fails with a much
    # less explicit error; same check as the session factory used for the
    # optimization.
    raise RuntimeError("No custom implementation available.")
opts = SessionOptions()
opts.register_custom_ops_library(r[0])

print(f"Loading modified {filename!r}")
sess_cus = InferenceSession(
    onx_modified.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)

print(f"Loading modified sparse {filename!r}")
sess_cus_sparse = InferenceSession(
    onx_modified_sparse.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)


print(f"Running once with shape {Xb.shape}.")
base = sess_ort.run(None, {"X": Xb})[0]

print(f"Running modified with shape {Xb.shape}.")
got = sess_cus.run(None, {"X": Xb})[0]
print("done.")

# The sparse kernel takes a 1D float tensor built from Xb.
Xb_sp = dense_to_sparse_struct(Xb)
print(f"Running modified sparse with shape {Xb_sp.shape}.")
got_sparse = sess_cus_sparse.run(None, {"X": Xb_sp})[0]
print("done.")
Loading 'plot_op_tree_ensemble_sparse-f500-10-d10-s0.99.onnx'
Creating SessionOptions with ['/home/xadupre/github/onnx-extended/onnx_extended/ortops/optim/cpu/libortops_optim_cpu.so']
Loading modified 'plot_op_tree_ensemble_sparse-f500-10-d10-s0.99.onnx'
Loading modified sparse 'plot_op_tree_ensemble_sparse-f500-10-d10-s0.99.onnx'
Running once with shape (1000, 500).
Running modified with shape (1000, 500).
done.
Running modified sparse with shape (10060,).
done.

Discrepancies?

for label, prediction in (
    ("Discrepancies", got),
    ("Discrepancies sparse", got_sparse),
):
    # Largest absolute difference with onnxruntime's predictions.
    diff = numpy.abs(base - prediction).max()
    print(f"{label}: {diff}")
Discrepancies: 0.00030517578125
Discrepancies sparse: 0.00030517578125

Simple verification

Baseline with onnxruntime.

# Total time of 50 calls with onnxruntime's own implementation.
t1 = timeit.Timer(lambda: sess_ort.run(None, {"X": Xb})).timeit(number=50)
print(f"baseline: {t1}")
baseline: 0.00941400000010617

The custom implementation.

# Total time of 50 calls with the custom dense kernel.
t2 = timeit.Timer(lambda: sess_cus.run(None, {"X": Xb})).timeit(number=50)
print(f"new time: {t2}")
new time: 0.021795399999973597

The custom sparse implementation.

# Total time of 50 calls with the custom sparse kernel on the sparse structure.
t3 = timeit.Timer(lambda: sess_cus_sparse.run(None, {"X": Xb_sp})).timeit(number=50)
print(f"new time sparse: {t3}")
new time sparse: 0.022183099999892875

Time for comparison

The custom kernel supports the same attributes as TreeEnsembleRegressor plus new ones to tune the parallelization. They can be seen in tree_ensemble.cc. Let’s try out many possibilities. The default values are the first ones.

# Values to try for every parallelization parameter of the custom kernel.
# The comments give the default value of each parameter in the kernel.
if unit_test_going():
    optim_params = dict(
        parallel_tree=[40],  # default is 80
        parallel_tree_N=[128],  # default is 128
        parallel_N=[50, 25],  # default is 50
        batch_size_tree=[1],  # default is 1
        batch_size_rows=[1],  # default is 1
        use_node3=[0],  # default is 0
    )
elif script_args.scenario in (None, "SHORT"):
    optim_params = dict(
        parallel_tree=[80, 40],  # default is 80
        parallel_tree_N=[128, 64],  # default is 128
        parallel_N=[50, 25],  # default is 50
        batch_size_tree=[1],  # default is 1
        batch_size_rows=[1],  # default is 1
        use_node3=[0],  # default is 0
    )
elif script_args.scenario == "LONG":
    optim_params = dict(
        parallel_tree=[80, 160, 40],
        parallel_tree_N=[256, 128, 64],
        parallel_N=[100, 50, 25],
        batch_size_tree=[1, 2, 4, 8],
        batch_size_rows=[1, 2, 4, 8],
        use_node3=[0, 1],
    )
elif script_args.scenario == "CUSTOM":
    # Every option is a comma separated list of integers such as "80,40";
    # empty items (a trailing comma for example) are ignored.
    optim_params = {}
    for key in (
        "parallel_tree",
        "parallel_tree_N",
        "parallel_N",
        "batch_size_tree",
        "batch_size_rows",
        "use_node3",
    ):
        values = [int(v) for v in getattr(script_args, key).split(",") if v.strip()]
        if not values:
            raise ValueError(f"Option --{key} must contain at least one integer.")
        optim_params[key] = values
else:
    raise ValueError(
        f"Unknown scenario {script_args.scenario!r}, use --help to get them."
    )

# Options reproducing this set of values with --scenario=CUSTOM.
cmds = [
    f"--{att}={','.join(map(str, value))}" for att, value in optim_params.items()
]
print("Full list of optimization parameters:")
print(" ".join(cmds))
Full list of optimization parameters:
--parallel_tree=80,40 --parallel_tree_N=128,64 --parallel_N=50,25 --batch_size_tree=1 --batch_size_rows=1 --use_node3=0

Then the optimization for the dense input.

def create_session(onx):
    """Builds an InferenceSession for *onx* with the custom operator library
    registered (first library returned by ``get_ort_ext_libs``).

    :param onx: ModelProto to load
    :return: InferenceSession on the CPUExecutionProvider
    :raises RuntimeError: if no custom library is available
    """
    libs = get_ort_ext_libs()
    if libs is None:
        raise RuntimeError("No custom implementation available.")
    options = SessionOptions()
    options.register_custom_ops_library(libs[0])
    serialized = onx.SerializeToString()
    return InferenceSession(serialized, options, providers=["CPUExecutionProvider"])


def _ort_baseline(model):
    """Reference session: onnxruntime only, without the custom library."""
    return InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )


# Benchmark settings taken from the command line.
_timing = dict(
    number=script_args.number,
    repeat=script_args.repeat,
    warmup=script_args.warmup,
    sleep=script_args.sleep,
    n_tries=script_args.tries,
)
res = optimize_model(
    onx,
    feeds={"X": Xb},
    transform=transform_model,
    session=create_session,
    baseline=_ort_baseline,
    params=optim_params,
    verbose=True,
    **_timing,
)
  0%|          | 0/16 [00:00<?, ?it/s]
i=1/16 TRY=0 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:   0%|          | 0/16 [00:00<?, ?it/s]
i=1/16 TRY=0 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:   6%|▋         | 1/16 [00:00<00:03,  3.83it/s]
i=2/16 TRY=0 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=0.56x:   6%|▋         | 1/16 [00:00<00:03,  3.83it/s]
i=2/16 TRY=0 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=0.56x:  12%|█▎        | 2/16 [00:00<00:02,  5.72it/s]
i=3/16 TRY=0 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.12x:  12%|█▎        | 2/16 [00:00<00:02,  5.72it/s]
i=3/16 TRY=0 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.12x:  19%|█▉        | 3/16 [00:00<00:01,  6.53it/s]
i=4/16 TRY=0 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.17x:  19%|█▉        | 3/16 [00:00<00:01,  6.53it/s]
i=4/16 TRY=0 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.17x:  25%|██▌       | 4/16 [00:00<00:01,  7.28it/s]
i=5/16 TRY=0 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.17x:  25%|██▌       | 4/16 [00:00<00:01,  7.28it/s]
i=5/16 TRY=0 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.17x:  31%|███▏      | 5/16 [00:00<00:01,  7.39it/s]
i=6/16 TRY=0 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.17x:  31%|███▏      | 5/16 [00:00<00:01,  7.39it/s]
i=6/16 TRY=0 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.17x:  38%|███▊      | 6/16 [00:00<00:01,  7.58it/s]
i=7/16 TRY=0 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  38%|███▊      | 6/16 [00:00<00:01,  7.58it/s]
i=7/16 TRY=0 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  44%|████▍     | 7/16 [00:01<00:01,  7.62it/s]
i=8/16 TRY=0 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  44%|████▍     | 7/16 [00:01<00:01,  7.62it/s]
i=8/16 TRY=0 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  50%|█████     | 8/16 [00:01<00:01,  7.98it/s]
i=9/16 TRY=1 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  50%|█████     | 8/16 [00:01<00:01,  7.98it/s]
i=9/16 TRY=1 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  56%|█████▋    | 9/16 [00:01<00:00,  7.61it/s]
i=10/16 TRY=1 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  56%|█████▋    | 9/16 [00:01<00:00,  7.61it/s]
i=10/16 TRY=1 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  62%|██████▎   | 10/16 [00:01<00:00,  7.26it/s]
i=11/16 TRY=1 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  62%|██████▎   | 10/16 [00:01<00:00,  7.26it/s]
i=11/16 TRY=1 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  69%|██████▉   | 11/16 [00:01<00:00,  7.41it/s]
i=12/16 TRY=1 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  69%|██████▉   | 11/16 [00:01<00:00,  7.41it/s]
i=12/16 TRY=1 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.20x:  75%|███████▌  | 12/16 [00:01<00:00,  7.79it/s]
i=13/16 TRY=1 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  75%|███████▌  | 12/16 [00:01<00:00,  7.79it/s]
i=13/16 TRY=1 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  81%|████████▏ | 13/16 [00:01<00:00,  7.54it/s]
i=14/16 TRY=1 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  81%|████████▏ | 13/16 [00:01<00:00,  7.54it/s]
i=14/16 TRY=1 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  88%|████████▊ | 14/16 [00:01<00:00,  7.46it/s]
i=15/16 TRY=1 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  88%|████████▊ | 14/16 [00:01<00:00,  7.46it/s]
i=15/16 TRY=1 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  94%|█████████▍| 15/16 [00:02<00:00,  7.33it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x:  94%|█████████▍| 15/16 [00:02<00:00,  7.33it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x: 100%|██████████| 16/16 [00:02<00:00,  7.23it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0  ~=1.79x: 100%|██████████| 16/16 [00:02<00:00,  7.21it/s]

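Then the optimization for the sparse input. There is no onnxruntime baseline here: the original model expects a dense 2D input, not the sparse structure.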

def _sparse_transform(*args, **kwargs):
    """Same as transform_model but always with ``use_sparse=True``."""
    return transform_model(*args, use_sparse=True, **kwargs)


# No baseline: the original model takes a dense 2D input and presumably
# cannot consume the 1D sparse structure Xb_sp.
res_sparse = optimize_model(
    onx,
    feeds={"X": Xb_sp},
    params=optim_params,
    transform=_sparse_transform,
    session=create_session,
    n_tries=script_args.tries,
    number=script_args.number,
    repeat=script_args.repeat,
    warmup=script_args.warmup,
    sleep=script_args.sleep,
    verbose=True,
)
  0%|          | 0/16 [00:00<?, ?it/s]
i=1/16 TRY=0 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:   0%|          | 0/16 [00:00<?, ?it/s]
i=1/16 TRY=0 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:   6%|▋         | 1/16 [00:00<00:03,  4.54it/s]
i=2/16 TRY=0 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:   6%|▋         | 1/16 [00:00<00:03,  4.54it/s]
i=2/16 TRY=0 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  12%|█▎        | 2/16 [00:00<00:03,  4.61it/s]
i=3/16 TRY=0 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  12%|█▎        | 2/16 [00:00<00:03,  4.61it/s]
i=3/16 TRY=0 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  19%|█▉        | 3/16 [00:00<00:02,  4.46it/s]
i=4/16 TRY=0 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  19%|█▉        | 3/16 [00:00<00:02,  4.46it/s]
i=4/16 TRY=0 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  25%|██▌       | 4/16 [00:00<00:02,  4.35it/s]
i=5/16 TRY=0 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  25%|██▌       | 4/16 [00:00<00:02,  4.35it/s]
i=5/16 TRY=0 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  31%|███▏      | 5/16 [00:01<00:02,  4.57it/s]
i=6/16 TRY=0 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  31%|███▏      | 5/16 [00:01<00:02,  4.57it/s]
i=6/16 TRY=0 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  38%|███▊      | 6/16 [00:01<00:02,  4.48it/s]
i=7/16 TRY=0 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  38%|███▊      | 6/16 [00:01<00:02,  4.48it/s]
i=7/16 TRY=0 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  44%|████▍     | 7/16 [00:01<00:02,  4.45it/s]
i=8/16 TRY=0 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  44%|████▍     | 7/16 [00:01<00:02,  4.45it/s]
i=8/16 TRY=0 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  50%|█████     | 8/16 [00:01<00:01,  4.37it/s]
i=9/16 TRY=1 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  50%|█████     | 8/16 [00:01<00:01,  4.37it/s]
i=9/16 TRY=1 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  56%|█████▋    | 9/16 [00:02<00:01,  4.34it/s]
i=10/16 TRY=1 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  56%|█████▋    | 9/16 [00:02<00:01,  4.34it/s]
i=10/16 TRY=1 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  62%|██████▎   | 10/16 [00:02<00:01,  4.36it/s]
i=11/16 TRY=1 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  62%|██████▎   | 10/16 [00:02<00:01,  4.36it/s]
i=11/16 TRY=1 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  69%|██████▉   | 11/16 [00:02<00:01,  4.32it/s]
i=12/16 TRY=1 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  69%|██████▉   | 11/16 [00:02<00:01,  4.32it/s]
i=12/16 TRY=1 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  75%|███████▌  | 12/16 [00:02<00:00,  4.32it/s]
i=13/16 TRY=1 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  75%|███████▌  | 12/16 [00:02<00:00,  4.32it/s]
i=13/16 TRY=1 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  81%|████████▏ | 13/16 [00:02<00:00,  4.39it/s]
i=14/16 TRY=1 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  81%|████████▏ | 13/16 [00:02<00:00,  4.39it/s]
i=14/16 TRY=1 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  88%|████████▊ | 14/16 [00:03<00:00,  4.35it/s]
i=15/16 TRY=1 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  88%|████████▊ | 14/16 [00:03<00:00,  4.35it/s]
i=15/16 TRY=1 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0:  94%|█████████▍| 15/16 [00:03<00:00,  4.33it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0:  94%|█████████▍| 15/16 [00:03<00:00,  4.33it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0: 100%|██████████| 16/16 [00:03<00:00,  4.36it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0: 100%|██████████| 16/16 [00:03<00:00,  4.39it/s]

And the results.

# One row per measure, the column "input" tells which kind of input was used.
df_dense = DataFrame(res).assign(input="dense")
df_sparse = DataFrame(res_sparse).assign(input="sparse")
df = concat([df_dense, df_sparse], axis=0)

# Raw measures saved for later analysis.
df.to_csv("plot_op_tree_ensemble_sparse.csv", index=False)
df.to_excel("plot_op_tree_ensemble_sparse.xlsx", index=False)
print(df.columns)
print(df.head(5))
Index(['average', 'deviation', 'min_exec', 'max_exec', 'repeat', 'number',
       'ttime', 'context_size', 'warmup_time', 'n_exp', 'n_exp_name',
       'short_name', 'TRY', 'name', 'parallel_tree', 'parallel_tree_N',
       'parallel_N', 'batch_size_tree', 'batch_size_rows', 'use_node3',
       'input'],
      dtype='object')
    average  deviation  min_exec  max_exec  repeat  number     ttime  context_size  warmup_time  n_exp  ...         short_name  TRY             name parallel_tree  parallel_tree_N  parallel_N  batch_size_tree  batch_size_rows  use_node3  input
0  0.000082   0.000012  0.000074  0.000114      10      10  0.000820            64     0.000752      0  ...         0,baseline  0.0         baseline           NaN              NaN         NaN              NaN              NaN        NaN  dense
1  0.000145   0.000157  0.000068  0.000614      10      10  0.001455            64     0.000968      0  ...  0,80,128,50,1,1,0  NaN  80,128,50,1,1,0          80.0            128.0        50.0              1.0              1.0        0.0  dense
2  0.000073   0.000003  0.000068  0.000082      10      10  0.000732            64     0.000719      1  ...  0,80,128,25,1,1,0  NaN  80,128,25,1,1,0          80.0            128.0        25.0              1.0              1.0        0.0  dense
3  0.000070   0.000003  0.000066  0.000075      10      10  0.000704            64     0.000721      2  ...   0,80,64,50,1,1,0  NaN   80,64,50,1,1,0          80.0             64.0        50.0              1.0              1.0        0.0  dense
4  0.000074   0.000006  0.000068  0.000086      10      10  0.000742            64     0.000832      3  ...   0,80,64,25,1,1,0  NaN   80,64,25,1,1,0          80.0             64.0        25.0              1.0              1.0        0.0  dense

[5 rows x 21 columns]

Sorting

# Columns not needed to rank the configurations.
dropped_columns = [
    "min_exec",
    "max_exec",
    "repeat",
    "number",
    "context_size",
    "n_exp_name",
]
small_df = df.drop(columns=dropped_columns).sort_values("average")
print(small_df.head(n=10))
     average  deviation     ttime  warmup_time  n_exp         short_name  TRY             name  parallel_tree  parallel_tree_N  parallel_N  batch_size_tree  batch_size_rows  use_node3  input
12  0.000046   0.000002  0.000459     0.000738     11   1,80,64,25,1,1,0  NaN   80,64,25,1,1,0           80.0             64.0        25.0              1.0              1.0        0.0  dense
7   0.000068   0.000004  0.000681     0.000768      6   0,40,64,50,1,1,0  NaN   40,64,50,1,1,0           40.0             64.0        50.0              1.0              1.0        0.0  dense
6   0.000069   0.000003  0.000686     0.000749      5  0,40,128,25,1,1,0  NaN  40,128,25,1,1,0           40.0            128.0        25.0              1.0              1.0        0.0  dense
8   0.000070   0.000005  0.000698     0.000789      7   0,40,64,25,1,1,0  NaN   40,64,25,1,1,0           40.0             64.0        25.0              1.0              1.0        0.0  dense
3   0.000070   0.000003  0.000704     0.000721      2   0,80,64,50,1,1,0  NaN   80,64,50,1,1,0           80.0             64.0        50.0              1.0              1.0        0.0  dense
2   0.000073   0.000003  0.000732     0.000719      1  0,80,128,25,1,1,0  NaN  80,128,25,1,1,0           80.0            128.0        25.0              1.0              1.0        0.0  dense
4   0.000074   0.000006  0.000742     0.000832      3   0,80,64,25,1,1,0  NaN   80,64,25,1,1,0           80.0             64.0        25.0              1.0              1.0        0.0  dense
0   0.000082   0.000012  0.000820     0.000752      0         0,baseline  0.0         baseline            NaN              NaN         NaN              NaN              NaN        NaN  dense
5   0.000118   0.000042  0.001181     0.000835      4  0,40,128,50,1,1,0  NaN  40,128,50,1,1,0           40.0            128.0        50.0              1.0              1.0        0.0  dense
11  0.000132   0.000004  0.001317     0.001155     10   1,80,64,50,1,1,0  NaN   80,64,50,1,1,0           80.0             64.0        50.0              1.0              1.0        0.0  dense

Worst

# small_df is sorted by increasing average time: the last rows are the slowest.
print(small_df.iloc[-10:])
     average  deviation     ttime  warmup_time  n_exp         short_name  TRY             name  parallel_tree  parallel_tree_N  parallel_N  batch_size_tree  batch_size_rows  use_node3   input
6   0.001115   0.000222  0.011151     0.006673      6   0,40,64,50,1,1,0  NaN   40,64,50,1,1,0           40.0             64.0        50.0              1.0              1.0        0.0  sparse
9   0.001115   0.000100  0.011151     0.005819      9  1,80,128,25,1,1,0  NaN  80,128,25,1,1,0           80.0            128.0        25.0              1.0              1.0        0.0  sparse
2   0.001155   0.000225  0.011554     0.006210      2   0,80,64,50,1,1,0  NaN   80,64,50,1,1,0           80.0             64.0        50.0              1.0              1.0        0.0  sparse
11  0.001181   0.000373  0.011808     0.005520     11   1,80,64,25,1,1,0  NaN   80,64,25,1,1,0           80.0             64.0        25.0              1.0              1.0        0.0  sparse
10  0.001187   0.000361  0.011867     0.007734     10   1,80,64,50,1,1,0  NaN   80,64,50,1,1,0           80.0             64.0        50.0              1.0              1.0        0.0  sparse
14  0.001190   0.000199  0.011895     0.005864     14   1,40,64,50,1,1,0  NaN   40,64,50,1,1,0           40.0             64.0        50.0              1.0              1.0        0.0  sparse
8   0.001190   0.000143  0.011899     0.006226      8  1,80,128,50,1,1,0  NaN  80,128,50,1,1,0           80.0            128.0        50.0              1.0              1.0        0.0  sparse
13  0.001194   0.000220  0.011941     0.007257     13  1,40,128,25,1,1,0  NaN  40,128,25,1,1,0           40.0            128.0        25.0              1.0              1.0        0.0  sparse
7   0.001232   0.000218  0.012324     0.005990      7   0,40,64,25,1,1,0  NaN   40,64,25,1,1,0           40.0             64.0        25.0              1.0              1.0        0.0  sparse
3   0.001244   0.000259  0.012437     0.005868      3   0,80,64,25,1,1,0  NaN   80,64,25,1,1,0           80.0             64.0        25.0              1.0              1.0        0.0  sparse

Plot

# Names of the tuned parameters, shown in the title.
skeys = ",".join(optim_params)
title = "\n".join(
    [
        f"TreeEnsemble tuning, n_tries={script_args.tries}",
        skeys,
        "lower is better",
    ]
)
ax = hhistograms(df, title=title, keys=("input", "name"))
fig = ax.get_figure()
fig.savefig("plot_op_tree_ensemble_sparse.png")
TreeEnsemble tuning, n_tries=2 parallel_tree,parallel_tree_N,parallel_N,batch_size_tree,batch_size_rows,use_node3 lower is better

Total running time of the script: (0 minutes 6.780 seconds)

Gallery generated by Sphinx-Gallery