{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Nearest neighbors\n", "\n", "We want to predict the quality score of a wine with a nearest-neighbors model." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"  <thead>\n",
"    <tr><th></th><th>fixed_acidity</th><th>volatile_acidity</th><th>citric_acid</th><th>residual_sugar</th><th>chlorides</th><th>free_sulfur_dioxide</th><th>total_sulfur_dioxide</th><th>density</th><th>pH</th><th>sulphates</th><th>alcohol</th><th>quality</th><th>color</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>7.4</td><td>0.70</td><td>0.00</td><td>1.9</td><td>0.076</td><td>11.0</td><td>34.0</td><td>0.9978</td><td>3.51</td><td>0.56</td><td>9.4</td><td>5</td><td>red</td></tr>\n",
"    <tr><th>1</th><td>7.8</td><td>0.88</td><td>0.00</td><td>2.6</td><td>0.098</td><td>25.0</td><td>67.0</td><td>0.9968</td><td>3.20</td><td>0.68</td><td>9.8</td><td>5</td><td>red</td></tr>\n",
"    <tr><th>2</th><td>7.8</td><td>0.76</td><td>0.04</td><td>2.3</td><td>0.092</td><td>15.0</td><td>54.0</td><td>0.9970</td><td>3.26</td><td>0.65</td><td>9.8</td><td>5</td><td>red</td></tr>\n",
"    <tr><th>3</th><td>11.2</td><td>0.28</td><td>0.56</td><td>1.9</td><td>0.075</td><td>17.0</td><td>60.0</td><td>0.9980</td><td>3.16</td><td>0.58</td><td>9.8</td><td>6</td><td>red</td></tr>\n",
"    <tr><th>4</th><td>7.4</td><td>0.70</td><td>0.00</td><td>1.9</td><td>0.076</td><td>11.0</td><td>34.0</td><td>0.9978</td><td>3.51</td><td>0.56</td><td>9.4</td><td>5</td><td>red</td></tr>\n",
"  </tbody>\n",
"</table>
" ], "text/plain": [ " fixed_acidity volatile_acidity citric_acid residual_sugar chlorides \\\n", "0 7.4 0.70 0.00 1.9 0.076 \n", "1 7.8 0.88 0.00 2.6 0.098 \n", "2 7.8 0.76 0.04 2.3 0.092 \n", "3 11.2 0.28 0.56 1.9 0.075 \n", "4 7.4 0.70 0.00 1.9 0.076 \n", "\n", " free_sulfur_dioxide total_sulfur_dioxide density pH sulphates \\\n", "0 11.0 34.0 0.9978 3.51 0.56 \n", "1 25.0 67.0 0.9968 3.20 0.68 \n", "2 15.0 54.0 0.9970 3.26 0.65 \n", "3 17.0 60.0 0.9980 3.16 0.58 \n", "4 11.0 34.0 0.9978 3.51 0.56 \n", "\n", " alcohol quality color \n", "0 9.4 5 red \n", "1 9.8 5 red \n", "2 9.8 5 red \n", "3 9.8 6 red \n", "4 9.4 5 red " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from teachpyx.datasets import load_wines_dataset\n", "\n", "df = load_wines_dataset()\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first columns describe the composition of the wine, or more precisely the results of chemical measurements. They exclude the quality, which is what we want to predict, and the color, which is not numerical information. These columns are called the **variables** or **features**: they are what is known at the time a prediction is made." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Features: every column except the target (quality) and the non-numeric color\n", "X = df.drop([\"quality\", \"color\"], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The *quality* column is what we want to predict." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "y = df[\"quality\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We fit a nearest-neighbors model. It simply memorizes all the training examples. The *n_neighbors* parameter sets how many neighbors take part in the prediction, and the predicted quality is the average quality of those neighbors. Here it is just 1, so the prediction is the quality of the single closest wine (by Euclidean distance on the features)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
KNeighborsRegressor(n_neighbors=1)
" ], "text/plain": [ "KNeighborsRegressor(n_neighbors=1)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neighbors import KNeighborsRegressor\n", "\n", "knn = KNeighborsRegressor(n_neighbors=1)\n", "knn.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We look at the prediction for a wine drawn at random from the dataset." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([6.])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import random\n", "\n", "i = random.randint(0, df.shape[0] - 1)\n", "# A one-row slice keeps the 2-D DataFrame shape that predict expects\n", "vin = X.iloc[i : i + 1, :]\n", "knn.predict(vin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We compare it with the score recorded in the dataset." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Positional access, consistent with X.iloc[i] above\n", "y.iloc[i]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two scores match. This is expected rather than a sign of a good model: the wine belongs to the training data, so with a single neighbor it is its own nearest neighbor (unless an identical wine has a different score). Judging the model requires wines it has not seen during training." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }
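The nearest-neighbors prediction rule used in this notebook can be written out by hand. The following is a minimal sketch, not the notebook's own code: it assumes scikit-learn's defaults for `KNeighborsRegressor` (Euclidean distance, uniform weights), and it uses a hypothetical helper `knn_predict` on synthetic random data standing in for the wine table (11 numeric features, integer scores). It compares the hand-written rule with scikit-learn, then shows that with one neighbor, predicting on the training rows simply returns their stored scores.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor


def knn_predict(X_train, y_train, X_query, k=1):
    """Hypothetical helper: average the targets of the k closest training rows."""
    # Pairwise squared Euclidean distances, shape (n_query, n_train).
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    # Indices of the k smallest distances for each query row.
    nearest = np.argsort(d2, axis=1, kind="stable")[:, :k]
    # Uniform weights: the prediction is the plain mean of the neighbors' targets.
    return y_train[nearest].mean(axis=1)


# Synthetic stand-in for the wine table (assumption: random data, not the real dataset).
rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 11))
y_train = rng.integers(3, 9, size=200).astype(float)
X_query = rng.normal(size=(5, 11))

# The hand-written rule should agree with scikit-learn for several values of k.
for k in (1, 5):
    model = KNeighborsRegressor(n_neighbors=k).fit(X_train, y_train)
    ours = knn_predict(X_train, y_train, X_query, k=k)
    print(k, np.allclose(ours, model.predict(X_query)))

# With k=1, each training row is its own nearest neighbor (distance 0),
# so predicting on the training set returns the stored scores.
model1 = KNeighborsRegressor(n_neighbors=1).fit(X_train, y_train)
print(np.array_equal(model1.predict(X_train), y_train))
```

Each comparison should print `True`. Ties between equidistant neighbors could make the two implementations pick different neighbors, but with continuous random features such ties are practically impossible.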