{ "cells": [ { "cell_type": "markdown", "id": "59a065e8", "metadata": {}, "source": [ "# NeuralTreeNet and cost\n", "\n", "The *NeuralTreeNet* class converts a decision tree into a neural network. The conversion is not exact, but it yields a differentiable model that can be trained with a gradient-based optimization algorithm. This notebook compares the execution time of a tree and of the resulting neural network." ] }, { "cell_type": "code", "execution_count": 1, "id": "e6ad71f6", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "52d5e7f8", "metadata": {}, "source": [ "## Dataset\n", "\n", "We build a random dataset." ] }, { "cell_type": "code", "execution_count": 2, "id": "0abef0bf", "metadata": {}, "outputs": [], "source": [ "import numpy\n", "\n", "X = numpy.random.randn(10000, 10)\n", "y = X.sum(axis=1) / X.shape[1]\n", "X = X.astype(numpy.float64)\n", "y = y.astype(numpy.float64)" ] }, { "cell_type": "code", "execution_count": 3, "id": "8850b4e7", "metadata": {}, "outputs": [], "source": [ "middle = X.shape[0] // 2\n", "X_train, X_test = X[:middle], X[middle:]\n", "y_train, y_test = y[:middle], y[middle:]" ] }, { "cell_type": "markdown", "id": "12c4a84c", "metadata": {}, "source": [ "## Fitting a decision tree" ] }, { "cell_type": "code", "execution_count": 4, "id": "1c0b0169", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.6225001966466359, 0.37938295559354807)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.tree import DecisionTreeRegressor\n", "\n", "tree = DecisionTreeRegressor(max_depth=7)\n", "tree.fit(X_train, y_train)\n", "tree.score(X_train, y_train), tree.score(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 5, "id": "6b158b44", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.37938295559354807" ] }, "execution_count": 5, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "from sklearn.metrics import r2_score\n", "\n", "r2_score(y_test, tree.predict(X_test))" ] }, { "cell_type": "markdown", "id": "36db83ef", "metadata": {}, "source": [ "Covnersion de l'arbre en réseau de neurones" ] }, { "cell_type": "code", "execution_count": 6, "id": "60e3e6ac", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>0</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>average absolute error</th><td>0.208776</td></tr>\n",
"    <tr><th>max absolute error</th><td>1.427806</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " 0\n", "average absolute error 0.208776\n", "max absolute error 1.427806" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pandas import DataFrame\n", "from mlstatpy.ml.neural_tree import NeuralTreeNet, NeuralTreeNetRegressor\n", "\n", "xe = X_test.astype(numpy.float32)\n", "expected = tree.predict(xe)\n", "\n", "nn = NeuralTreeNetRegressor(NeuralTreeNet.create_from_tree(tree, arch=\"compact\"))\n", "got = nn.predict(xe)\n", "me = numpy.abs(got - expected).mean()\n", "mx = numpy.abs(got - expected).max()\n", "DataFrame([{\"average absolute error\": me, \"max absolute error\": mx}]).T" ] }, { "cell_type": "markdown", "id": "559f0a25", "metadata": {}, "source": [ "La conversion est loin d'être parfaite. La raison vient du fait que les fonctions de seuil sont approchées par des fonctions sigmoïdes. Il suffit d'une erreur minime pour que la décision prenne un chemin différent dans l'arbre et soit complètement différente." ] }, { "cell_type": "markdown", "id": "c1ad28cf", "metadata": {}, "source": [ "## Conversion au format ONNX" ] }, { "cell_type": "code", "execution_count": 7, "id": "b01518ec", "metadata": {}, "outputs": [], "source": [ "from skl2onnx import to_onnx\n", "\n", "onx_tree = to_onnx(tree, X[:1].astype(numpy.float32))\n", "onx_nn = to_onnx(nn, X[:1].astype(numpy.float32))" ] }, { "cell_type": "markdown", "id": "f59994a5", "metadata": {}, "source": [ "Le réseau de neurones peut être représenté comme suit." 
] }, { "cell_type": "code", "execution_count": 8, "id": "a33fedcb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "opset: domain='' version=21\n", "input: name='X' type=dtype('float32') shape=['', 10]\n", "init: name='Ma_MatMulcst' type=dtype('float32') shape=(10, 127)\n", "init: name='Ad_Addcst' type=dtype('float32') shape=(127,)\n", "init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)\n", "init: name='Ma_MatMulcst1' type=dtype('float32') shape=(127, 128)\n", "init: name='Ad_Addcst1' type=dtype('float32') shape=(128,)\n", "init: name='Ma_MatMulcst2' type=dtype('float32') shape=(128, 1)\n", "init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)\n", "MatMul(X, Ma_MatMulcst) -> Ma_Y02\n", " Add(Ma_Y02, Ad_Addcst) -> Ad_C02\n", " Mul(Ad_C02, Mu_Mulcst) -> Mu_C01\n", " Sigmoid(Mu_C01) -> Si_Y01\n", " MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01\n", " Add(Ma_Y01, Ad_Addcst1) -> Ad_C01\n", " Mul(Ad_C01, Mu_Mulcst) -> Mu_C0\n", " Sigmoid(Mu_C0) -> Si_Y0\n", " MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0\n", " Add(Ma_Y0, Ad_Addcst2) -> Ad_C0\n", " Identity(Ad_C0) -> variable\n", "output: name='variable' type=dtype('float32') shape=['', 1]\n" ] } ], "source": [ "from onnx_array_api.plotting.text_plot import onnx_simple_text_plot\n", "\n", "print(onnx_simple_text_plot(onx_nn))" ] }, { "cell_type": "markdown", "id": "857f2e42", "metadata": {}, "source": [ "## Prediction time" ] }, { "cell_type": "code", "execution_count": 9, "id": "7c810819", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "584 μs ± 16.3 μs per loop (mean ± std. dev. 
of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "from onnxruntime import InferenceSession\n", "\n", "oinf_tree = InferenceSession(onx_tree.SerializeToString())\n", "oinf_nn = InferenceSession(onx_nn.SerializeToString())\n", "\n", "%timeit tree.predict(xe)" ] }, { "cell_type": "code", "execution_count": 10, "id": "51f6c958", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "48.4 μs ± 1.16 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit oinf_tree.run(None, {'X': xe})" ] }, { "cell_type": "code", "execution_count": 11, "id": "ab1ff3a8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.28 ms ± 97.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%timeit oinf_nn.run(None, {'X': xe})" ] }, { "cell_type": "markdown", "id": "5d8ecaa5", "metadata": {}, "source": [ "The neural network is noticeably slower. For a decision tree of depth *d*, a prediction requires at most *d* comparisons. The neural network, by contrast, evaluates every threshold for every prediction, that is up to $2^d - 1$ of them, on the order of $2^d$. Let us check this by varying the depth." ] }, { "cell_type": "markdown", "id": "b27675e4", "metadata": {}, "source": [ "## Prediction time as a function of depth" ] }, { "cell_type": "code", "execution_count": 12, "id": "ecef383a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 8/8 [00:04<00:00, 1.63it/s]\n" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>average</th><th>deviation</th><th>min_exec</th><th>max_exec</th><th>repeat</th><th>number</th><th>ttime</th><th>context_size</th><th>d</th><th>exp</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0.000207</td><td>0.000045</td><td>0.000153</td><td>0.000317</td><td>20</td><td>20</td><td>0.004147</td><td>64</td><td>2</td><td>skl</td></tr>\n",
"    <tr><th>1</th><td>0.000151</td><td>0.000246</td><td>0.000031</td><td>0.000825</td><td>10</td><td>10</td><td>0.001515</td><td>64</td><td>2</td><td>onx_tree</td></tr>\n",
"    <tr><th>2</th><td>0.000178</td><td>0.000093</td><td>0.000119</td><td>0.000371</td><td>10</td><td>10</td><td>0.001781</td><td>64</td><td>2</td><td>onx_nn</td></tr>\n",
"    <tr><th>3</th><td>0.000249</td><td>0.000036</td><td>0.000220</td><td>0.000360</td><td>20</td><td>20</td><td>0.004980</td><td>64</td><td>3</td><td>skl</td></tr>\n",
"    <tr><th>4</th><td>0.000312</td><td>0.000156</td><td>0.000113</td><td>0.000661</td><td>10</td><td>10</td><td>0.003117</td><td>64</td><td>3</td><td>onx_tree</td></tr>\n",
"    <tr><th>5</th><td>0.000352</td><td>0.000204</td><td>0.000182</td><td>0.000831</td><td>10</td><td>10</td><td>0.003523</td><td>64</td><td>3</td><td>onx_nn</td></tr>\n",
"    <tr><th>6</th><td>0.000339</td><td>0.000073</td><td>0.000257</td><td>0.000487</td><td>20</td><td>20</td><td>0.006775</td><td>64</td><td>4</td><td>skl</td></tr>\n",
"    <tr><th>7</th><td>0.000337</td><td>0.000423</td><td>0.000059</td><td>0.001537</td><td>10</td><td>10</td><td>0.003368</td><td>64</td><td>4</td><td>onx_tree</td></tr>\n",
"    <tr><th>8</th><td>0.000619</td><td>0.000354</td><td>0.000221</td><td>0.001320</td><td>10</td><td>10</td><td>0.006194</td><td>64</td><td>4</td><td>onx_nn</td></tr>\n",
"    <tr><th>9</th><td>0.000359</td><td>0.000038</td><td>0.000309</td><td>0.000453</td><td>20</td><td>20</td><td>0.007171</td><td>64</td><td>5</td><td>skl</td></tr>\n",
"    <tr><th>10</th><td>0.000473</td><td>0.000565</td><td>0.000064</td><td>0.001923</td><td>10</td><td>10</td><td>0.004729</td><td>64</td><td>5</td><td>onx_tree</td></tr>\n",
"    <tr><th>11</th><td>0.001197</td><td>0.000944</td><td>0.000309</td><td>0.003529</td><td>10</td><td>10</td><td>0.011973</td><td>64</td><td>5</td><td>onx_nn</td></tr>\n",
"    <tr><th>12</th><td>0.000386</td><td>0.000022</td><td>0.000359</td><td>0.000439</td><td>20</td><td>20</td><td>0.007715</td><td>64</td><td>6</td><td>skl</td></tr>\n",
"    <tr><th>13</th><td>0.000793</td><td>0.000770</td><td>0.000097</td><td>0.002445</td><td>10</td><td>10</td><td>0.007926</td><td>64</td><td>6</td><td>onx_tree</td></tr>\n",
"    <tr><th>14</th><td>0.001521</td><td>0.000919</td><td>0.000652</td><td>0.003820</td><td>10</td><td>10</td><td>0.015207</td><td>64</td><td>6</td><td>onx_nn</td></tr>\n",
"    <tr><th>15</th><td>0.000429</td><td>0.000024</td><td>0.000404</td><td>0.000494</td><td>20</td><td>20</td><td>0.008579</td><td>64</td><td>7</td><td>skl</td></tr>\n",
"    <tr><th>16</th><td>0.000658</td><td>0.000662</td><td>0.000207</td><td>0.002484</td><td>10</td><td>10</td><td>0.006575</td><td>64</td><td>7</td><td>onx_tree</td></tr>\n",
"    <tr><th>17</th><td>0.002925</td><td>0.002770</td><td>0.001489</td><td>0.011048</td><td>10</td><td>10</td><td>0.029251</td><td>64</td><td>7</td><td>onx_nn</td></tr>\n",
"    <tr><th>18</th><td>0.000508</td><td>0.000059</td><td>0.000452</td><td>0.000733</td><td>20</td><td>20</td><td>0.010157</td><td>64</td><td>8</td><td>skl</td></tr>\n",
"    <tr><th>19</th><td>0.001235</td><td>0.001208</td><td>0.000121</td><td>0.003842</td><td>10</td><td>10</td><td>0.012347</td><td>64</td><td>8</td><td>onx_tree</td></tr>\n",
"    <tr><th>20</th><td>0.004627</td><td>0.004239</td><td>0.002962</td><td>0.017300</td><td>10</td><td>10</td><td>0.046271</td><td>64</td><td>8</td><td>onx_nn</td></tr>\n",
"    <tr><th>21</th><td>0.000558</td><td>0.000045</td><td>0.000498</td><td>0.000700</td><td>20</td><td>20</td><td>0.011152</td><td>64</td><td>9</td><td>skl</td></tr>\n",
"    <tr><th>22</th><td>0.000745</td><td>0.000540</td><td>0.000138</td><td>0.002166</td><td>10</td><td>10</td><td>0.007449</td><td>64</td><td>9</td><td>onx_tree</td></tr>\n",
"    <tr><th>23</th><td>0.011127</td><td>0.004856</td><td>0.009014</td><td>0.025667</td><td>10</td><td>10</td><td>0.111265</td><td>64</td><td>9</td><td>onx_nn</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " average deviation min_exec max_exec repeat number ttime \\\n", "0 0.000207 0.000045 0.000153 0.000317 20 20 0.004147 \n", "1 0.000151 0.000246 0.000031 0.000825 10 10 0.001515 \n", "2 0.000178 0.000093 0.000119 0.000371 10 10 0.001781 \n", "3 0.000249 0.000036 0.000220 0.000360 20 20 0.004980 \n", "4 0.000312 0.000156 0.000113 0.000661 10 10 0.003117 \n", "5 0.000352 0.000204 0.000182 0.000831 10 10 0.003523 \n", "6 0.000339 0.000073 0.000257 0.000487 20 20 0.006775 \n", "7 0.000337 0.000423 0.000059 0.001537 10 10 0.003368 \n", "8 0.000619 0.000354 0.000221 0.001320 10 10 0.006194 \n", "9 0.000359 0.000038 0.000309 0.000453 20 20 0.007171 \n", "10 0.000473 0.000565 0.000064 0.001923 10 10 0.004729 \n", "11 0.001197 0.000944 0.000309 0.003529 10 10 0.011973 \n", "12 0.000386 0.000022 0.000359 0.000439 20 20 0.007715 \n", "13 0.000793 0.000770 0.000097 0.002445 10 10 0.007926 \n", "14 0.001521 0.000919 0.000652 0.003820 10 10 0.015207 \n", "15 0.000429 0.000024 0.000404 0.000494 20 20 0.008579 \n", "16 0.000658 0.000662 0.000207 0.002484 10 10 0.006575 \n", "17 0.002925 0.002770 0.001489 0.011048 10 10 0.029251 \n", "18 0.000508 0.000059 0.000452 0.000733 20 20 0.010157 \n", "19 0.001235 0.001208 0.000121 0.003842 10 10 0.012347 \n", "20 0.004627 0.004239 0.002962 0.017300 10 10 0.046271 \n", "21 0.000558 0.000045 0.000498 0.000700 20 20 0.011152 \n", "22 0.000745 0.000540 0.000138 0.002166 10 10 0.007449 \n", "23 0.011127 0.004856 0.009014 0.025667 10 10 0.111265 \n", "\n", " context_size d exp \n", "0 64 2 skl \n", "1 64 2 onx_tree \n", "2 64 2 onx_nn \n", "3 64 3 skl \n", "4 64 3 onx_tree \n", "5 64 3 onx_nn \n", "6 64 4 skl \n", "7 64 4 onx_tree \n", "8 64 4 onx_nn \n", "9 64 5 skl \n", "10 64 5 onx_tree \n", "11 64 5 onx_nn \n", "12 64 6 skl \n", "13 64 6 onx_tree \n", "14 64 6 onx_nn \n", "15 64 7 skl \n", "16 64 7 onx_tree \n", "17 64 7 onx_nn \n", "18 64 8 skl \n", "19 64 8 onx_tree \n", "20 64 8 onx_nn \n", "21 64 9 skl \n", "22 64 9 
onx_tree \n", "23 64 9 onx_nn " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tqdm import tqdm\n", "from onnx_array_api.ext_test_case import measure_time\n", "\n", "data = []\n", "for d in tqdm(range(2, 10)):\n", " tree = DecisionTreeRegressor(max_depth=d)\n", " tree.fit(X_train, y_train)\n", " obs = measure_time(lambda tree=tree: tree.predict(xe), number=20, repeat=20)\n", " obs.update(dict(d=d, exp=\"skl\"))\n", " data.append(obs)\n", "\n", " nn = NeuralTreeNetRegressor(NeuralTreeNet.create_from_tree(tree, arch=\"compact\"))\n", "\n", " onx_tree = to_onnx(tree, X[:1].astype(numpy.float32))\n", " onx_nn = to_onnx(nn, X[:1].astype(numpy.float32))\n", " oinf_tree = InferenceSession(\n", " onx_tree.SerializePartialToString(), providers=[\"CPUExecutionProvider\"]\n", " )\n", " oinf_nn = InferenceSession(\n", " onx_nn.SerializePartialToString(), providers=[\"CPUExecutionProvider\"]\n", " )\n", "\n", " obs = measure_time(\n", " lambda oinf_tree=oinf_tree: oinf_tree.run(None, {\"X\": xe}), number=10, repeat=10\n", " )\n", " obs.update(dict(d=d, exp=\"onx_tree\"))\n", " data.append(obs)\n", "\n", " obs = measure_time(\n", " lambda oinf_nn=oinf_nn: oinf_nn.run(None, {\"X\": xe}), number=10, repeat=10\n", " )\n", " obs.update(dict(d=d, exp=\"onx_nn\"))\n", " data.append(obs)\n", "\n", "df = DataFrame(data)\n", "df" ] }, { "cell_type": "code", "execution_count": 14, "id": "871130aa", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "piv = df.pivot(index=\"d\", columns=\"exp\", values=\"average\")\n", "piv.plot(logy=True, title=\"Temps de calcul en fonction de la profondeur\");" ] }, { "cell_type": "markdown", "id": "b30bbefe", "metadata": {}, "source": [ "L'hypothèse est vérifiée." ] }, { "cell_type": "code", "execution_count": 17, "id": "8b163d83", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }