{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification multi-classe\n", "\n", "On cherche à prédire la note d'un vin avec un classifieur multi-classe." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [], "source": [ "from teachpyx.datasets import load_wines_dataset\n", "\n", "df = load_wines_dataset()\n", "# target: the \"quality\" column; features: all remaining columns except \"color\"\n", "y = df[\"quality\"]\n", "X = df.drop(columns=[\"quality\", \"color\"])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# no random_state: the split, hence every score below, changes from one run to the next\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LogisticRegression(solver='liblinear')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LogisticRegression(solver='liblinear')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "clr = LogisticRegression(solver=\"liblinear\")\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "55.07692307692308" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy\n", "\n", "# accuracy (%) on the test set: predictions compared element-wise with the true labels\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On regarde la matrice de confusion." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456
00047000
1004722000
200332184000
3001695411000
400192172200
500342500
60001000
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6\n", "0 0 0 4 7 0 0 0\n", "1 0 0 47 22 0 0 0\n", "2 0 0 332 184 0 0 0\n", "3 0 0 169 541 10 0 0\n", "4 0 0 19 217 22 0 0\n", "5 0 0 3 42 5 0 0\n", "6 0 0 0 1 0 0 0" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas\n", "from sklearn.metrics import confusion_matrix\n", "\n", "# raw matrix: row i = true class, column j = predicted class, both indexed by position\n", "pandas.DataFrame(data=confusion_matrix(y_true=y_test, y_pred=clr.predict(X_test)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On l'affiche différemment avec le nom des classes." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
3456789
30047000
4004722000
500332184000
6001695411000
700192172200
800342500
90001000
\n", "
" ], "text/plain": [ " 3 4 5 6 7 8 9\n", "3 0 0 4 7 0 0 0\n", "4 0 0 47 22 0 0 0\n", "5 0 0 332 184 0 0 0\n", "6 0 0 169 541 10 0 0\n", "7 0 0 19 217 22 0 0\n", "8 0 0 3 42 5 0 0\n", "9 0 0 0 1 0 0 0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Label rows/columns with the class values. A rare class (e.g. 9) may be missing\n", "# from the training split, so clr.classes_ alone may not cover y_test: take the\n", "# union of both and pass it to confusion_matrix so the matrix follows that order.\n", "labels = numpy.union1d(clr.classes_, y_test)\n", "conf = confusion_matrix(y_test, clr.predict(X_test), labels=labels)\n", "# rows: true classes, columns: predicted classes\n", "dfconf = pandas.DataFrame(conf, index=labels, columns=labels)\n", "dfconf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pas extraordinaire. On applique la stratégie [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
OneVsRestClassifier(estimator=LogisticRegression(solver='liblinear'))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "OneVsRestClassifier(estimator=LogisticRegression(solver='liblinear'))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.multiclass import OneVsRestClassifier\n", "\n", "clr = OneVsRestClassifier(LogisticRegression(solver=\"liblinear\"))\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "54.95384615384615" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Avec le solveur *liblinear*, la régression logistique multi-classe applique déjà la stratégie *OneVsRest*, d'où des scores quasiment identiques. Voyons l'autre stratégie, *OneVsOne*." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
OneVsOneClassifier(estimator=LogisticRegression(solver='liblinear'))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "OneVsOneClassifier(estimator=LogisticRegression(solver='liblinear'))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.multiclass import OneVsOneClassifier\n", "\n", "clr = OneVsOneClassifier(LogisticRegression(solver=\"liblinear\"))\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "55.138461538461534" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
3456789
30056000
4004623000
500332183100
6001695242700
700182004000
8006321200
90001000
\n", "
" ], "text/plain": [ " 3 4 5 6 7 8 9\n", "3 0 0 5 6 0 0 0\n", "4 0 0 46 23 0 0 0\n", "5 0 0 332 183 1 0 0\n", "6 0 0 169 524 27 0 0\n", "7 0 0 18 200 40 0 0\n", "8 0 0 6 32 12 0 0\n", "9 0 0 0 1 0 0 0" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Label rows/columns with the class values. A rare class (e.g. 9) may be missing\n", "# from the training split, so clr.classes_ alone may not cover y_test: take the\n", "# union of both and pass it to confusion_matrix so the matrix follows that order.\n", "labels = numpy.union1d(clr.classes_, y_test)\n", "conf = confusion_matrix(y_test, clr.predict(X_test), labels=labels)\n", "# rows: true classes, columns: predicted classes\n", "dfconf = pandas.DataFrame(conf, index=labels, columns=labels)\n", "dfconf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "À peu près pareil, mais la différence n'est sans doute pas significative. Voyons avec un arbre de décision." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
DecisionTreeClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "DecisionTreeClassifier()" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.tree import DecisionTreeClassifier\n", "\n", "clr = DecisionTreeClassifier()\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "59.323076923076925" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Et avec [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) :" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
OneVsRestClassifier(estimator=DecisionTreeClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "OneVsRestClassifier(estimator=DecisionTreeClassifier())" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clr = OneVsRestClassifier(DecisionTreeClassifier())\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "53.35384615384615" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Et avec [OneVsOneClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsOneClassifier.html) :" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
OneVsOneClassifier(estimator=DecisionTreeClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "OneVsOneClassifier(estimator=DecisionTreeClassifier())" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clr = OneVsOneClassifier(DecisionTreeClassifier())\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "62.58461538461538" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mieux. Essayons maintenant une forêt aléatoire." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestClassifier()" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "clr = RandomForestClassifier()\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "69.2923076923077" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
OneVsRestClassifier(estimator=RandomForestClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "OneVsRestClassifier(estimator=RandomForestClassifier())" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clr = OneVsRestClassifier(RandomForestClassifier())\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "69.41538461538461" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Les deux scores sont proches : il faudrait une validation croisée pour les départager. Essayons enfin un réseau de neurones." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
MLPClassifier(hidden_layer_sizes=30, max_iter=600)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "MLPClassifier(hidden_layer_sizes=30, max_iter=600)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neural_network import MLPClassifier\n", "\n", "clr = MLPClassifier(hidden_layer_sizes=30, max_iter=600)\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "52.800000000000004" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
OneVsRestClassifier(estimator=MLPClassifier(hidden_layer_sizes=30,\n",
       "                                            max_iter=600))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "OneVsRestClassifier(estimator=MLPClassifier(hidden_layer_sizes=30,\n", " max_iter=600))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clr = OneVsRestClassifier(MLPClassifier(hidden_layer_sizes=30, max_iter=600))\n", "clr.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "52.800000000000004" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# accuracy (%) on the test set\n", "numpy.mean(clr.predict(X_test) == y_test.to_numpy()) * 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pas foudroyant : le réseau de neurones fait nettement moins bien que la forêt aléatoire. Les variables ne sont pas normalisées ici, ce qui gêne souvent l'apprentissage d'un réseau de neurones." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }