Note
Go to the end to download the full example code
TreeEnsemble optimization¶
The execution time of a TreeEnsembleRegressor can vary widely depending on how the computation is parallelized: by trees, by rows, by both, for only one row, for a short batch of rows, or for a longer one. The implementation in onnxruntime does not let the user change the predetermined settings, but a custom kernel might. That’s what this example is measuring.
The default set of optimized parameters is very short and is meant to be executed fast. Many more parameters can be tried.
python plot_op_tree_ensemble_optim.py --scenario=LONG
To change the training parameters:
python plot_op_tree_ensemble_optim.py
--n_trees=100
--max_depth=10
--n_features=50
--batch_size=100000
Another example with a full list of parameters:
python plot_op_tree_ensemble_optim.py
--n_trees=100 --max_depth=10 --n_features=50 --batch_size=100000 --tries=3 --scenario=CUSTOM --parallel_tree=80,40 --parallel_tree_N=128,64 --parallel_N=50,25 --batch_size_tree=1,2 --batch_size_rows=1,2 --use_node3=0
Another example:
python plot_op_tree_ensemble_optim.py
--n_trees=100 --n_features=10 --batch_size=10000 --max_depth=8 -s SHORT
import logging
import os
import timeit
from typing import Tuple
import numpy
import onnx
from onnx import ModelProto
from onnx.helper import make_graph, make_model
from onnx.reference import ReferenceEvaluator
from pandas import DataFrame
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from skl2onnx import to_onnx
from onnxruntime import InferenceSession, SessionOptions
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_extended.reference import CReferenceEvaluator
from onnx_extended.ortops.optim.cpu import get_ort_ext_libs
from onnx_extended.ortops.optim.optimize import (
change_onnx_operator_domain,
get_node_attribute,
optimize_model,
)
from onnx_extended.tools.onnx_nodes import multiply_tree
from onnx_extended.args import get_parsed_args
from onnx_extended.ext_test_case import unit_test_going
from onnx_extended.plotting.benchmark import hhistograms
# Only let ERROR (and above) messages from matplotlib's font manager through.
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)

# Command line options of this example.
# Most entries are written as ``name=(default value, help text)``.
# The first four defaults are smaller when unit tests are running.
script_args = get_parsed_args(
    "plot_op_tree_ensemble_optim",
    description=__doc__,
    scenarios={
        "SHORT": "short optimization (default)",
        "LONG": "test more options",
        "CUSTOM": "use values specified by the command line",
    },
    n_features=(2 if unit_test_going() else 5, "number of features to generate"),
    n_trees=(3 if unit_test_going() else 10, "number of trees to train"),
    max_depth=(2 if unit_test_going() else 5, "max_depth"),
    batch_size=(1000 if unit_test_going() else 10000, "batch size"),
    # The six options below are comma separated lists of integers,
    # they are split and converted when the scenario is CUSTOM.
    parallel_tree=("80,160,40", "values to try for parallel_tree"),
    parallel_tree_N=("256,128,64", "values to try for parallel_tree_N"),
    parallel_N=("100,50,25", "values to try for parallel_N"),
    batch_size_tree=("2,4,8", "values to try for batch_size_tree"),
    batch_size_rows=("2,4,8", "values to try for batch_size_rows"),
    use_node3=("0,1", "values to try for use_node3"),
    expose="",
    # Kept as a string, it is converted with int(...) where the forest is trained.
    n_jobs=("-1", "number of jobs to train the RandomForestRegressor"),
)
Training a model¶
def train_model(
    batch_size: int, n_features: int, n_trees: int, max_depth: int
) -> Tuple[str, numpy.ndarray, numpy.ndarray]:
    """
    Returns the path of an ONNX tree ensemble and a batch of data to run it.

    The model file is named after *n_features*, *n_trees* and *max_depth*.
    When the file already exists, it is reused as is and only random
    regression data is generated. Otherwise a ``RandomForestRegressor``
    with a single tree is trained on synthetic data, converted to ONNX,
    its node goes through ``multiply_tree(node, n_trees)`` and the
    resulting model is saved under that name.

    :param batch_size: number of rows of the returned batch
    :param n_features: number of features of the generated data
    :param n_trees: value given to ``multiply_tree``
    :param max_depth: maximum depth of the trained tree
    :return: filename, features and targets of the batch, both as float32
    """
    filename = f"plot_op_tree_ensemble_optim-f{n_features}-{n_trees}-d{max_depth}.onnx"

    if os.path.exists(filename):
        # The model is already on disk, only the data is needed.
        X, y = make_regression(batch_size, n_features=n_features, n_targets=1)
        X, y = X.astype(numpy.float32), y.astype(numpy.float32)
    else:
        # Rows for training come first, the last *batch_size* rows are kept
        # for the returned batch.
        n_rows = batch_size + max(batch_size, 2 ** (max_depth + 1))
        X, y = make_regression(n_rows, n_features=n_features, n_targets=1)
        print(f"Training to get {filename!r} with X.shape={X.shape}")
        X = X.astype(numpy.float32)
        y = y.astype(numpy.float32)

        # Only one tree is trained to keep the training fast.
        forest = RandomForestRegressor(
            1, max_depth=max_depth, verbose=2, n_jobs=int(script_args.n_jobs)
        )
        forest.fit(X[:-batch_size], y[:-batch_size])
        converted = to_onnx(forest, X[:1], target_opset={"": 18, "ai.onnx.ml": 3})

        # The trees are then multiplied, and a new model is built around
        # the resulting node.
        tree_node = multiply_tree(converted.graph.node[0], n_trees)
        graph = make_graph(
            [tree_node],
            converted.graph.name,
            converted.graph.input,
            converted.graph.output,
        )
        onx = make_model(
            graph,
            domain=converted.domain,
            opset_imports=converted.opset_import,
            ir_version=converted.ir_version,
        )

        with open(filename, "wb") as f:
            f.write(onx.SerializeToString())

    Xb = X[-batch_size:].copy()
    yb = y[-batch_size:].copy()
    return filename, Xb, yb
# Dimensions of the experiment, taken from the command line options.
batch_size, n_features, n_trees, max_depth = (
    script_args.batch_size,
    script_args.n_features,
    script_args.n_trees,
    script_args.max_depth,
)

# One value per line.
print(
    f"batch_size={batch_size}",
    f"n_features={n_features}",
    f"n_trees={n_trees}",
    f"max_depth={max_depth}",
    sep="\n",
)
batch_size=10000
n_features=5
n_trees=10
max_depth=5
training
# Gets the model file (trained on the first call, reused afterwards)
# and the batch of data used by every measure below.
filename, Xb, yb = train_model(batch_size, n_features, n_trees, max_depth)

print(f"Xb.shape={Xb.shape}")
print(f"yb.shape={yb.shape}")
Training to get 'plot_op_tree_ensemble_optim-f5-10-d5.onnx' with X.shape=(20000, 5)
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
building tree 1 of 1
[Parallel(n_jobs=-1)]: Done 1 out of 1 | elapsed: 0.1s finished
Xb.shape=(10000, 5)
yb.shape=(10000,)
Rewrite the onnx file to use a different kernel¶
The custom kernel is mapped to a custom operator with the same name and the same attributes, but with domain = “onnx_extended.ortops.optim.cpu”. We call a function to do that replacement. First the current model.
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=18
input: name='X' type=dtype('float32') shape=['', 5]
TreeEnsembleRegressor(X, n_targets=1, nodes_falsenodeids=630:[32,17,10...62,0,0], nodes_featureids=630:[2,0,1...0,2,0], nodes_hitrates=630:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=630:[0,0,0...0,0,0], nodes_modes=630:[b'BRANCH_LEQ',b'BRANCH_LEQ'...b'LEAF',b'LEAF'], nodes_nodeids=630:[0,1,2...60,61,62], nodes_treeids=630:[0,0,0...9,9,9], nodes_truenodeids=630:[1,2,3...61,0,0], nodes_values=630:[-0.2562659978866577,-0.02141805738210678...-1.2375441789627075,1.2125707864761353], post_transform=b'NONE', target_ids=320:[0,0,0...0,0,0], target_nodeids=320:[5,6,8...59,61,62], target_treeids=320:[0,0,0...9,9,9], target_weights=320:[-304.95086669921875,-216.4969482421875...200.1363983154297,282.164794921875]) -> variable
output: name='variable' type=dtype('float32') shape=['', 1]
And then the modified model.
def transform_model(model, **kwargs):
    """
    Moves the TreeEnsembleRegressor of a model to the custom domain
    ``onnx_extended.ortops.optim.cpu``.

    The replacement is applied to a copy of *model* built by a
    serialization round trip. Attribute ``nodes_modes`` of the first node
    is rewritten as a single comma separated string in which the prefix
    ``BRANCH_`` is removed.

    :param model: ModelProto, its first node must have an attribute
        ``nodes_modes``
    :param kwargs: additional arguments forwarded to
        ``change_onnx_operator_domain``
    :return: the result of ``change_onnx_operator_domain``
    """
    onx = ModelProto()
    onx.ParseFromString(model.SerializeToString())

    modes_attribute = get_node_attribute(onx.graph.node[0], "nodes_modes")
    decoded_modes = [s.decode("ascii") for s in modes_attribute.strings]
    modes = ",".join(decoded_modes).replace("BRANCH_", "")

    return change_onnx_operator_domain(
        onx,
        op_type="TreeEnsembleRegressor",
        op_domain="ai.onnx.ml",
        new_op_domain="onnx_extended.ortops.optim.cpu",
        nodes_modes=modes,
        **kwargs,
    )
# NOTE(review): `onx` is not assigned in the code visible above, presumably
# it is the model loaded from `filename` - confirm.
print("Tranform model to add a custom node.")
onx_modified = transform_model(onx)
# The modified model is saved under `filename` followed by "modified.onnx".
print(f"Save into {filename + 'modified.onnx'!r}.")
with open(filename + "modified.onnx", "wb") as f:
    f.write(onx_modified.SerializeToString())
print("done.")
# Text description of the modified model.
print(onnx_simple_text_plot(onx_modified))
Tranform model to add a custom node.
Save into 'plot_op_tree_ensemble_optim-f5-10-d5.onnxmodified.onnx'.
done.
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=18
opset: domain='onnx_extended.ortops.optim.cpu' version=1
input: name='X' type=dtype('float32') shape=['', 5]
TreeEnsembleRegressor[onnx_extended.ortops.optim.cpu](X, nodes_modes=b'LEQ,LEQ,LEQ,LEQ,LEQ,LEAF,LEAF,LEQ,LEAF...LEAF,LEAF', n_targets=1, nodes_falsenodeids=630:[32,17,10...62,0,0], nodes_featureids=630:[2,0,1...0,2,0], nodes_hitrates=630:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=630:[0,0,0...0,0,0], nodes_nodeids=630:[0,1,2...60,61,62], nodes_treeids=630:[0,0,0...9,9,9], nodes_truenodeids=630:[1,2,3...61,0,0], nodes_values=630:[-0.2562659978866577,-0.02141805738210678...-1.2375441789627075,1.2125707864761353], post_transform=b'NONE', target_ids=320:[0,0,0...0,0,0], target_nodeids=320:[5,6,8...59,61,62], target_treeids=320:[0,0,0...9,9,9], target_weights=320:[-304.95086669921875,-216.4969482421875...200.1363983154297,282.164794921875]) -> variable
output: name='variable' type=dtype('float32') shape=['', 1]
Comparing onnxruntime and the custom kernel¶
# Baseline: the original model, loaded from the file by onnxruntime.
print(f"Loading {filename!r}")
sess_ort = InferenceSession(filename, providers=["CPUExecutionProvider"])

# The first library returned by get_ort_ext_libs is registered as
# a custom ops library in the session options.
r = get_ort_ext_libs()
print(f"Creating SessionOptions with {r!r}")
opts = SessionOptions()
if r is not None:
    opts.register_custom_ops_library(r[0])

# NOTE(review): the message displays `filename` but the session is created
# from the bytes of `onx_modified`, not from a file.
print(f"Loading modified {filename!r}")
sess_cus = InferenceSession(
    onx_modified.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)

# Both sessions are run once on the same batch, first output only.
print(f"Running once with shape {Xb.shape}.")
base = sess_ort.run(None, {"X": Xb})[0]

print(f"Running modified with shape {Xb.shape}.")
got = sess_cus.run(None, {"X": Xb})[0]
print("done.")
Loading 'plot_op_tree_ensemble_optim-f5-10-d5.onnx'
Creating SessionOptions with ['/home/xadupre/github/onnx-extended/onnx_extended/ortops/optim/cpu/libortops_optim_cpu.so']
Loading modified 'plot_op_tree_ensemble_optim-f5-10-d5.onnx'
Running once with shape (10000, 5).
Running modified with shape (10000, 5).
done.
Discrepancies?
Discrepancies: max=2.4154769562301226e-07, mean=5.230815602885741e-08 (A=1516.10205078125)
Simple verification¶
Baseline with onnxruntime.
# Total time of 50 runs of the onnxruntime session on the batch Xb.
t1 = timeit.timeit(lambda: sess_ort.run(None, {"X": Xb}), number=50)
print(f"baseline: {t1}")
baseline: 0.0604920999994647
The custom implementation.
# Total time of 50 runs of the session using the custom kernel, same batch.
t2 = timeit.timeit(lambda: sess_cus.run(None, {"X": Xb}), number=50)
print(f"new time: {t2}")
new time: 0.0373423000000912
The same implementation but run from the onnx python backend.
# CReferenceEvaluator is created from the original model file.
ref = CReferenceEvaluator(filename)
# One run outside the timed section.
ref.run(None, {"X": Xb})
# Total time of 50 runs on the same batch.
t3 = timeit.timeit(lambda: ref.run(None, {"X": Xb}), number=50)
print(f"CReferenceEvaluator: {t3}")
CReferenceEvaluator: 0.04880120000052557
The python implementation but from the onnx python backend.
ReferenceEvaluator: 4.576077199999418 (only 5 times instead of 50)
Time for comparison¶
The custom kernel supports the same attributes as TreeEnsembleRegressor plus new ones to tune the parallelization. They can be seen in tree_ensemble.cc. Let’s try out many possibilities. The default values are the first ones.
# Values to try for every tuning attribute. Unit tests use the smallest
# grid, otherwise the scenario given on the command line selects it.
if unit_test_going():
    optim_params = {
        "parallel_tree": [40],  # default is 80
        "parallel_tree_N": [128],  # default is 128
        "parallel_N": [50, 25],  # default is 50
        "batch_size_tree": [1],  # default is 1
        "batch_size_rows": [1],  # default is 1
        "use_node3": [0],  # default is 0
    }
elif script_args.scenario in (None, "SHORT"):
    optim_params = {
        "parallel_tree": [80, 40],  # default is 80
        "parallel_tree_N": [128, 64],  # default is 128
        "parallel_N": [50, 25],  # default is 50
        "batch_size_tree": [1],  # default is 1
        "batch_size_rows": [1],  # default is 1
        "use_node3": [0],  # default is 0
    }
elif script_args.scenario == "LONG":
    optim_params = {
        "parallel_tree": [80, 160, 40],
        "parallel_tree_N": [256, 128, 64],
        "parallel_N": [100, 50, 25],
        "batch_size_tree": [1, 2, 4, 8],
        "batch_size_rows": [1, 2, 4, 8],
        "use_node3": [0, 1],
    }
elif script_args.scenario == "CUSTOM":
    # Every option is a comma separated list of integers.
    optim_params = {
        "parallel_tree": [int(i) for i in script_args.parallel_tree.split(",")],
        "parallel_tree_N": [int(i) for i in script_args.parallel_tree_N.split(",")],
        "parallel_N": [int(i) for i in script_args.parallel_N.split(",")],
        "batch_size_tree": [int(i) for i in script_args.batch_size_tree.split(",")],
        "batch_size_rows": [int(i) for i in script_args.batch_size_rows.split(",")],
        "use_node3": [int(i) for i in script_args.use_node3.split(",")],
    }
else:
    raise ValueError(
        f"Unknown scenario {script_args.scenario!r}, use --help to get them."
    )

# The selected grid, displayed as command line options.
cmds = [
    f"--{att}={','.join(map(str, value))}" for att, value in optim_params.items()
]
print("Full list of optimization parameters:")
print(" ".join(cmds))
Full list of optimization parameters:
--parallel_tree=80,40 --parallel_tree_N=128,64 --parallel_N=50,25 --batch_size_tree=1 --batch_size_rows=1 --use_node3=0
Then the optimization.
def create_session(onx):
    """
    Creates an onnxruntime session able to run a model using the custom
    kernels.

    The first library returned by ``get_ort_ext_libs`` is registered as
    a custom ops library before the session is created on CPU.

    :param onx: model, it must implement ``SerializeToString``
    :return: InferenceSession
    :raises RuntimeError: if ``get_ort_ext_libs`` returns None
    """
    options = SessionOptions()
    libraries = get_ort_ext_libs()
    if libraries is None:
        raise RuntimeError("No custom implementation available.")
    options.register_custom_ops_library(libraries[0])

    serialized = onx.SerializeToString()
    return InferenceSession(serialized, options, providers=["CPUExecutionProvider"])
# Runs the optimization over the grid `optim_params`.
# NOTE(review): `onx` is not assigned in the code visible above, presumably
# it is the model loaded from `filename` - confirm.
res = optimize_model(
    onx,
    feeds={"X": Xb},
    # transform_model receives the model and returns the one using the custom domain
    transform=transform_model,
    # sessions for the transformed models need the custom ops library
    session=create_session,
    # the baseline is a plain onnxruntime session, no custom library
    baseline=lambda onx: InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    ),
    params=optim_params,
    verbose=True,
    # the measure settings are read from the command line options
    number=script_args.number,
    repeat=script_args.repeat,
    warmup=script_args.warmup,
    sleep=script_args.sleep,
    n_tries=script_args.tries,
)
0%| | 0/16 [00:00<?, ?it/s]
i=1/16 TRY=0 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0: 0%| | 0/16 [00:00<?, ?it/s]
i=1/16 TRY=0 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0: 6%|▋ | 1/16 [00:00<00:07, 1.93it/s]
i=2/16 TRY=0 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 6%|▋ | 1/16 [00:00<00:07, 1.93it/s]
i=2/16 TRY=0 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 12%|█▎ | 2/16 [00:00<00:04, 2.90it/s]
i=3/16 TRY=0 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 12%|█▎ | 2/16 [00:00<00:04, 2.90it/s]
i=3/16 TRY=0 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 19%|█▉ | 3/16 [00:01<00:04, 3.20it/s]
i=4/16 TRY=0 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 19%|█▉ | 3/16 [00:01<00:04, 3.20it/s]
i=4/16 TRY=0 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 25%|██▌ | 4/16 [00:01<00:03, 3.46it/s]
i=5/16 TRY=0 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 25%|██▌ | 4/16 [00:01<00:03, 3.46it/s]
i=5/16 TRY=0 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 31%|███▏ | 5/16 [00:01<00:02, 3.71it/s]
i=6/16 TRY=0 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 31%|███▏ | 5/16 [00:01<00:02, 3.71it/s]
i=6/16 TRY=0 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 38%|███▊ | 6/16 [00:01<00:02, 3.72it/s]
i=7/16 TRY=0 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 38%|███▊ | 6/16 [00:01<00:02, 3.72it/s]
i=7/16 TRY=0 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 44%|████▍ | 7/16 [00:02<00:02, 3.78it/s]
i=8/16 TRY=0 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 44%|████▍ | 7/16 [00:02<00:02, 3.78it/s]
i=8/16 TRY=0 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 50%|█████ | 8/16 [00:02<00:02, 3.79it/s]
i=9/16 TRY=1 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 50%|█████ | 8/16 [00:02<00:02, 3.79it/s]
i=9/16 TRY=1 //tree=80 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 56%|█████▋ | 9/16 [00:02<00:01, 3.77it/s]
i=10/16 TRY=1 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 56%|█████▋ | 9/16 [00:02<00:01, 3.77it/s]
i=10/16 TRY=1 //tree=80 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 62%|██████▎ | 10/16 [00:02<00:01, 3.94it/s]
i=11/16 TRY=1 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 62%|██████▎ | 10/16 [00:02<00:01, 3.94it/s]
i=11/16 TRY=1 //tree=80 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 69%|██████▉ | 11/16 [00:03<00:01, 3.86it/s]
i=12/16 TRY=1 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 69%|██████▉ | 11/16 [00:03<00:01, 3.86it/s]
i=12/16 TRY=1 //tree=80 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 75%|███████▌ | 12/16 [00:03<00:01, 3.88it/s]
i=13/16 TRY=1 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 75%|███████▌ | 12/16 [00:03<00:01, 3.88it/s]
i=13/16 TRY=1 //tree=40 //tree_N=128 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 81%|████████▏ | 13/16 [00:03<00:00, 3.87it/s]
i=14/16 TRY=1 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 81%|████████▏ | 13/16 [00:03<00:00, 3.87it/s]
i=14/16 TRY=1 //tree=40 //tree_N=128 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 88%|████████▊ | 14/16 [00:03<00:00, 3.78it/s]
i=15/16 TRY=1 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 88%|████████▊ | 14/16 [00:03<00:00, 3.78it/s]
i=15/16 TRY=1 //tree=40 //tree_N=64 //N=50 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 94%|█████████▍| 15/16 [00:04<00:00, 3.84it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 94%|█████████▍| 15/16 [00:04<00:00, 3.84it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 100%|██████████| 16/16 [00:04<00:00, 3.84it/s]
i=16/16 TRY=1 //tree=40 //tree_N=64 //N=25 bs_tree=1 batch_size_rows=1 n3=0 ~=1.84x: 100%|██████████| 16/16 [00:04<00:00, 3.67it/s]
And the results.
# The results are converted into a DataFrame and saved as csv and excel,
# without the index.
df = DataFrame(res)
df.to_csv("plot_op_tree_ensemble_optim.csv", index=False)
df.to_excel("plot_op_tree_ensemble_optim.xlsx", index=False)
# Column names and first five rows.
print(df.columns)
print(df.head(5))
Index(['average', 'deviation', 'min_exec', 'max_exec', 'repeat', 'number',
'ttime', 'context_size', 'warmup_time', 'n_exp', 'n_exp_name',
'short_name', 'TRY', 'name', 'parallel_tree', 'parallel_tree_N',
'parallel_N', 'batch_size_tree', 'batch_size_rows', 'use_node3'],
dtype='object')
average deviation min_exec max_exec repeat number ttime ... name parallel_tree parallel_tree_N parallel_N batch_size_tree batch_size_rows use_node3
0 0.001890 0.000238 0.001652 0.002443 10 10 0.018897 ... baseline NaN NaN NaN NaN NaN NaN
1 0.001028 0.000132 0.000930 0.001334 10 10 0.010275 ... 80,128,50,1,1,0 80.0 128.0 50.0 1.0 1.0 0.0
2 0.001132 0.000252 0.000870 0.001699 10 10 0.011315 ... 80,128,25,1,1,0 80.0 128.0 25.0 1.0 1.0 0.0
3 0.001557 0.000779 0.000672 0.002900 10 10 0.015567 ... 80,64,50,1,1,0 80.0 64.0 50.0 1.0 1.0 0.0
4 0.001373 0.000648 0.000602 0.002341 10 10 0.013733 ... 80,64,25,1,1,0 80.0 64.0 25.0 1.0 1.0 0.0
[5 rows x 20 columns]
Sorting¶
# Fewer columns, rows sorted by increasing average time.
small_df = df.drop(
    columns=[
        "min_exec",
        "max_exec",
        "repeat",
        "number",
        "context_size",
        "n_exp_name",
    ]
).sort_values(by="average")

# The ten first rows, the lowest averages.
print(small_df.head(n=10))
average deviation ttime warmup_time n_exp short_name TRY name parallel_tree parallel_tree_N parallel_N batch_size_tree batch_size_rows use_node3
1 0.001028 0.000132 0.010275 0.005814 0 0,80,128,50,1,1,0 NaN 80,128,50,1,1,0 80.0 128.0 50.0 1.0 1.0 0.0
10 0.001079 0.000608 0.010794 0.015019 9 1,80,128,25,1,1,0 NaN 80,128,25,1,1,0 80.0 128.0 25.0 1.0 1.0 0.0
5 0.001085 0.000659 0.010851 0.022635 4 0,40,128,50,1,1,0 NaN 40,128,50,1,1,0 40.0 128.0 50.0 1.0 1.0 0.0
2 0.001132 0.000252 0.011315 0.006750 1 0,80,128,25,1,1,0 NaN 80,128,25,1,1,0 80.0 128.0 25.0 1.0 1.0 0.0
7 0.001332 0.000658 0.013321 0.018723 6 0,40,64,50,1,1,0 NaN 40,64,50,1,1,0 40.0 64.0 50.0 1.0 1.0 0.0
12 0.001334 0.000851 0.013342 0.012163 11 1,80,64,25,1,1,0 NaN 80,64,25,1,1,0 80.0 64.0 25.0 1.0 1.0 0.0
15 0.001334 0.000845 0.013344 0.013864 14 1,40,64,50,1,1,0 NaN 40,64,50,1,1,0 40.0 64.0 50.0 1.0 1.0 0.0
4 0.001373 0.000648 0.013733 0.010307 3 0,80,64,25,1,1,0 NaN 80,64,25,1,1,0 80.0 64.0 25.0 1.0 1.0 0.0
13 0.001407 0.000771 0.014072 0.014144 12 1,40,128,50,1,1,0 NaN 40,128,50,1,1,0 40.0 128.0 50.0 1.0 1.0 0.0
8 0.001412 0.000957 0.014121 0.012743 7 0,40,64,25,1,1,0 NaN 40,64,25,1,1,0 40.0 64.0 25.0 1.0 1.0 0.0
Worst¶
# The ten last rows of small_df.
print(small_df.tail(n=10))
average deviation ttime warmup_time n_exp short_name TRY name parallel_tree parallel_tree_N parallel_N batch_size_tree batch_size_rows use_node3
13 0.001407 0.000771 0.014072 0.014144 12 1,40,128,50,1,1,0 NaN 40,128,50,1,1,0 40.0 128.0 50.0 1.0 1.0 0.0
8 0.001412 0.000957 0.014121 0.012743 7 0,40,64,25,1,1,0 NaN 40,64,25,1,1,0 40.0 64.0 25.0 1.0 1.0 0.0
16 0.001438 0.001196 0.014382 0.011125 15 1,40,64,25,1,1,0 NaN 40,64,25,1,1,0 40.0 64.0 25.0 1.0 1.0 0.0
6 0.001495 0.000952 0.014954 0.012180 5 0,40,128,25,1,1,0 NaN 40,128,25,1,1,0 40.0 128.0 25.0 1.0 1.0 0.0
9 0.001500 0.000992 0.014999 0.014506 8 1,80,128,50,1,1,0 NaN 80,128,50,1,1,0 80.0 128.0 50.0 1.0 1.0 0.0
11 0.001531 0.001024 0.015314 0.011912 10 1,80,64,50,1,1,0 NaN 80,64,50,1,1,0 80.0 64.0 50.0 1.0 1.0 0.0
3 0.001557 0.000779 0.015567 0.010704 2 0,80,64,50,1,1,0 NaN 80,64,50,1,1,0 80.0 64.0 50.0 1.0 1.0 0.0
14 0.001607 0.001029 0.016072 0.012801 13 1,40,128,25,1,1,0 NaN 40,128,25,1,1,0 40.0 128.0 25.0 1.0 1.0 0.0
0 0.001890 0.000238 0.018897 0.008754 0 0,baseline 0.0 baseline NaN NaN NaN NaN NaN NaN
17 0.001966 0.001533 0.019659 0.024567 0 1,baseline 1.0 baseline NaN NaN NaN NaN NaN NaN
Plot¶
# The title lists the number of tries and the names of the tuned attributes.
skeys = ",".join(optim_params.keys())
title = f"TreeEnsemble tuning, n_tries={script_args.tries}\n{skeys}\nlower is better"
# hhistograms receives the results and the title, keyed on column "name".
ax = hhistograms(df, title=title, keys=("name",))
# The figure is saved as a png file.
fig = ax.get_figure()
fig.savefig("plot_op_tree_ensemble_optim.png")
Total running time of the script: (0 minutes 11.191 seconds)