{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classification multi-classe\n",
"\n",
"On cherche à prédire la note d'un vin avec un classifieur multi-classe."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from teachpyx.datasets import load_wines_dataset\n",
"\n",
"df = load_wines_dataset()\n",
"X = df.drop([\"quality\", \"color\"], axis=1)\n",
"y = df[\"quality\"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
LogisticRegression(solver='liblinear')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"LogisticRegression(solver='liblinear')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"clr = LogisticRegression(solver=\"liblinear\")\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"55.07692307692308"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy\n",
"\n",
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On regarde la matrice de confusion."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
" 6 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 4 | \n",
" 7 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 47 | \n",
" 22 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 0 | \n",
" 0 | \n",
" 332 | \n",
" 184 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
" 169 | \n",
" 541 | \n",
" 10 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 0 | \n",
" 0 | \n",
" 19 | \n",
" 217 | \n",
" 22 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 5 | \n",
" 0 | \n",
" 0 | \n",
" 3 | \n",
" 42 | \n",
" 5 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 6 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 2 3 4 5 6\n",
"0 0 0 4 7 0 0 0\n",
"1 0 0 47 22 0 0 0\n",
"2 0 0 332 184 0 0 0\n",
"3 0 0 169 541 10 0 0\n",
"4 0 0 19 217 22 0 0\n",
"5 0 0 3 42 5 0 0\n",
"6 0 0 0 1 0 0 0"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"import pandas\n",
"\n",
"pandas.DataFrame(confusion_matrix(y_test, clr.predict(X_test)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On l'affiche différemment avec le nom des classes."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
" 6 | \n",
" 7 | \n",
" 8 | \n",
" 9 | \n",
"
\n",
" \n",
" \n",
" \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
" 4 | \n",
" 7 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 0 | \n",
" 0 | \n",
" 47 | \n",
" 22 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 5 | \n",
" 0 | \n",
" 0 | \n",
" 332 | \n",
" 184 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 6 | \n",
" 0 | \n",
" 0 | \n",
" 169 | \n",
" 541 | \n",
" 10 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 7 | \n",
" 0 | \n",
" 0 | \n",
" 19 | \n",
" 217 | \n",
" 22 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 8 | \n",
" 0 | \n",
" 0 | \n",
" 3 | \n",
" 42 | \n",
" 5 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 9 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 3 4 5 6 7 8 9\n",
"3 0 0 4 7 0 0 0\n",
"4 0 0 47 22 0 0 0\n",
"5 0 0 332 184 0 0 0\n",
"6 0 0 169 541 10 0 0\n",
"7 0 0 19 217 22 0 0\n",
"8 0 0 3 42 5 0 0\n",
"9 0 0 0 1 0 0 0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conf = confusion_matrix(y_test, clr.predict(X_test))\n",
"dfconf = pandas.DataFrame(conf)\n",
"labels = list(clr.classes_)\n",
"if len(labels) < dfconf.shape[1]:\n",
" labels += [\n",
" 9\n",
" ] # La classe 9 est très représentée, elle est parfois absente en train.\n",
"elif len(labels) > dfconf.shape[1]:\n",
" labels = labels[: dfconf.shape[1]] # ou l'inverse\n",
"dfconf.columns = labels\n",
"dfconf.index = labels\n",
"dfconf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pas extraordinaire. On applique la stratégie [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html)."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"OneVsRestClassifier(estimator=LogisticRegression(solver='liblinear'))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"OneVsRestClassifier(estimator=LogisticRegression(solver='liblinear'))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.multiclass import OneVsRestClassifier\n",
"\n",
"clr = OneVsRestClassifier(LogisticRegression(solver=\"liblinear\"))\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"54.95384615384615"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Le modèle logistique régression multi-classe est équivalent à la stratégie *OneVsRest*. Voyons l'autre."
]
},
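{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A minimal sketch of that check (not part of the original notebook): refit both variants on the same training set and measure how often their test predictions agree. The exact agreement rate depends on the split and on liblinear's internals, so this is an illustration rather than a proof."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch: liblinear handles multi-class with a one-vs-rest scheme internally,\n",
  "# so its predictions should essentially match an explicit OneVsRestClassifier.\n",
  "from sklearn.linear_model import LogisticRegression\n",
  "from sklearn.multiclass import OneVsRestClassifier\n",
  "\n",
  "lr = LogisticRegression(solver=\"liblinear\").fit(X_train, y_train)\n",
  "ovr = OneVsRestClassifier(LogisticRegression(solver=\"liblinear\")).fit(X_train, y_train)\n",
  "numpy.mean(lr.predict(X_test) == ovr.predict(X_test)) * 100"
 ]
},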
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"OneVsOneClassifier(estimator=LogisticRegression(solver='liblinear'))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"OneVsOneClassifier(estimator=LogisticRegression(solver='liblinear'))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.multiclass import OneVsOneClassifier\n",
"\n",
"clr = OneVsOneClassifier(LogisticRegression(solver=\"liblinear\"))\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"55.138461538461534"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
" 6 | \n",
" 7 | \n",
" 8 | \n",
" 9 | \n",
"
\n",
" \n",
" \n",
" \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
" 5 | \n",
" 6 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 0 | \n",
" 0 | \n",
" 46 | \n",
" 23 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 5 | \n",
" 0 | \n",
" 0 | \n",
" 332 | \n",
" 183 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 6 | \n",
" 0 | \n",
" 0 | \n",
" 169 | \n",
" 524 | \n",
" 27 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 7 | \n",
" 0 | \n",
" 0 | \n",
" 18 | \n",
" 200 | \n",
" 40 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 8 | \n",
" 0 | \n",
" 0 | \n",
" 6 | \n",
" 32 | \n",
" 12 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 9 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 3 4 5 6 7 8 9\n",
"3 0 0 5 6 0 0 0\n",
"4 0 0 46 23 0 0 0\n",
"5 0 0 332 183 1 0 0\n",
"6 0 0 169 524 27 0 0\n",
"7 0 0 18 200 40 0 0\n",
"8 0 0 6 32 12 0 0\n",
"9 0 0 0 1 0 0 0"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conf = confusion_matrix(y_test, clr.predict(X_test))\n",
"dfconf = pandas.DataFrame(conf)\n",
"labels = list(clr.classes_)\n",
"if len(labels) < dfconf.shape[1]:\n",
" labels += [\n",
" 9\n",
" ] # La classe 9 est très représentée, elle est parfois absente en train.\n",
"elif len(labels) > dfconf.shape[1]:\n",
" labels = labels[: dfconf.shape[1]] # ou l'inverse\n",
"dfconf.columns = labels\n",
"dfconf.index = labels\n",
"dfconf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A peu près pareil mais sans doute pas de manière significative. Voyons avec un arbre de décision."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"DecisionTreeClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"DecisionTreeClassifier()"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"clr = DecisionTreeClassifier()\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"59.323076923076925"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Et avec [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) :"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"OneVsRestClassifier(estimator=DecisionTreeClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"OneVsRestClassifier(estimator=DecisionTreeClassifier())"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clr = OneVsRestClassifier(DecisionTreeClassifier())\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"53.35384615384615"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Et avec [OneVsOneClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsOneClassifier.html)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"OneVsOneClassifier(estimator=DecisionTreeClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"OneVsOneClassifier(estimator=DecisionTreeClassifier())"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clr = OneVsOneClassifier(DecisionTreeClassifier())\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"62.58461538461538"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mieux."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"RandomForestClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"RandomForestClassifier()"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"clr = RandomForestClassifier()\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"69.2923076923077"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"OneVsRestClassifier(estimator=RandomForestClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"OneVsRestClassifier(estimator=RandomForestClassifier())"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clr = OneVsRestClassifier(RandomForestClassifier())\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"69.41538461538461"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Proche, il faut affiner avec une validation croisée."
]
},
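{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Below is a minimal sketch of such a cross-validated comparison (not part of the original experiment): it scores a plain `RandomForestClassifier` and its `OneVsRestClassifier` counterpart with `cross_val_score` on the full dataset, so each strategy gets a spread of accuracies rather than a single split."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch: cross-validated comparison of the two strategies on the full dataset.\n",
  "from sklearn.model_selection import cross_val_score\n",
  "\n",
  "for name, model in [\n",
  "    (\"RandomForest\", RandomForestClassifier()),\n",
  "    (\"OvR(RandomForest)\", OneVsRestClassifier(RandomForestClassifier())),\n",
  "]:\n",
  "    scores = cross_val_score(model, X, y, cv=5)\n",
  "    print(name, scores.mean(), scores.std())"
 ]
},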
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"MLPClassifier(hidden_layer_sizes=30, max_iter=600)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"MLPClassifier(hidden_layer_sizes=30, max_iter=600)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"clr = MLPClassifier(hidden_layer_sizes=30, max_iter=600)\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"52.800000000000004"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"OneVsRestClassifier(estimator=MLPClassifier(hidden_layer_sizes=30,\n",
" max_iter=600))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"OneVsRestClassifier(estimator=MLPClassifier(hidden_layer_sizes=30,\n",
" max_iter=600))"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clr = OneVsRestClassifier(MLPClassifier(hidden_layer_sizes=30, max_iter=600))\n",
"clr.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"52.800000000000004"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pas foudroyant."
]
}
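,
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "One likely reason (an assumption, not tested in the original run) is that, unlike trees, a neural network is sensitive to the scale of its input features. Here is a minimal sketch of the usual fix, standardizing the features in a pipeline before the MLP:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch: scale the features before the MLP; no scaling was applied above.\n",
  "from sklearn.pipeline import make_pipeline\n",
  "from sklearn.preprocessing import StandardScaler\n",
  "\n",
  "pipe = make_pipeline(StandardScaler(), MLPClassifier(hidden_layer_sizes=30, max_iter=600))\n",
  "pipe.fit(X_train, y_train)\n",
  "numpy.mean(pipe.predict(X_test).ravel() == y_test.ravel()) * 100"
 ]
}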
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}