{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stratified train / test split\n",
"\n",
"When a class is under-represented, a plain random train / test split is unlikely to preserve the class distribution."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from teachpyx.datasets import load_wines_dataset\n",
"\n",
"df = load_wines_dataset()\n",
"X = df.drop([\"quality\", \"color\"], axis=1)\n",
"y = df[\"quality\"]"
]
},
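{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to see the imbalance is to count how many wines fall into each quality grade with `value_counts`: the extreme grades are far rarer than the middle ones."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# number of wines per quality grade\n",
"y.value_counts().sort_index()"
]
},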
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We split the data into a training set and a test set with the [train_test_split](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) function."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y)"
]
},
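{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check whether the split preserves the class distribution, the next cell counts the wines of each grade in the training and test sets and computes the test / train ratio for every class."
]
},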
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"base test train ratio\n",
"y \n",
"3 9 21 0.428571\n",
"4 45 171 0.263158\n",
"5 553 1585 0.348896\n",
"6 702 2134 0.328960\n",
"7 261 818 0.319071\n",
"8 51 142 0.359155\n",
"9 4 1 4.000000"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas\n",
"\n",
"ys = pandas.DataFrame(dict(y=y_train))\n",
"ys[\"base\"] = \"train\"\n",
"ys2 = pandas.DataFrame(dict(y=y_test))\n",
"ys2[\"base\"] = \"test\"\n",
"ys = pandas.concat([ys, ys2])\n",
"ys[\"compte\"] = 1\n",
"piv = (\n",
" ys.groupby([\"base\", \"y\"], as_index=False)\n",
" .count()\n",
" .pivot(index=\"y\", columns=\"base\", values=\"compte\")\n",
")\n",
"piv[\"ratio\"] = piv[\"test\"] / piv[\"train\"]\n",
"piv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The test / train ratio is close to 1/3 for every class, which matches the default `test_size` of 0.25 (0.25 / 0.75 is about 1/3), except for the under-represented grades. We therefore use a stratified split: the distribution of one variable, the labels, will be the same in the training and the test sets. The code follows the example from [StratifiedShuffleSplit](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html), with a single split (`n_splits=1`) and a 33% test set."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4352, 2145)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import StratifiedShuffleSplit\n",
"\n",
"split = StratifiedShuffleSplit(n_splits=1, test_size=0.33)\n",
"train_index, test_index = list(split.split(X, y))[0]\n",
"len(train_index), len(test_index)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((4352,), (2145,))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train, y_train = X.iloc[train_index, :], y.iloc[train_index]\n",
"X_test, y_test = X.iloc[test_index, :], y.iloc[test_index]\n",
"y_train.shape, y_test.shape"
]
},
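{
"cell_type": "markdown",
"metadata": {},
"source": [
"We run the same count and pivot as above, this time on the stratified split."
]
},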
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"base test train ratio\n",
"y \n",
"3 10 20 0.500000\n",
"4 71 145 0.489655\n",
"5 706 1432 0.493017\n",
"6 936 1900 0.492632\n",
"7 356 723 0.492393\n",
"8 64 129 0.496124\n",
"9 2 3 0.666667"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ys = pandas.DataFrame(dict(y=y_train))\n",
"ys[\"base\"] = \"train\"\n",
"ys2 = pandas.DataFrame(dict(y=y_test))\n",
"ys2[\"base\"] = \"test\"\n",
"ys = pandas.concat([ys, ys2])\n",
"ys[\"compte\"] = 1\n",
"piv = (\n",
" ys.groupby([\"base\", \"y\"], as_index=False)\n",
" .count()\n",
" .pivot(index=\"y\", columns=\"base\", values=\"compte\")\n",
")\n",
"piv[\"ratio\"] = piv[\"test\"] / piv[\"train\"]\n",
"piv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The test / train ratio is now nearly identical across classes: the test set contains about half as many individuals as the training set, and this holds for every class except class 9, which has so few examples that an exact split is impossible."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsRegressor(n_neighbors=1)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"knn = KNeighborsRegressor(n_neighbors=1)\n",
"knn.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"prediction = knn.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-0.15256811411591387"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import r2_score\n",
"\n",
"r2_score(y_test, prediction)"
]
},
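{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put this score into perspective, one can refit the same model on a plain, non-stratified split. The names `X_tr`, `X_te`, `y_tr`, `y_te` and `knn_plain` below are only illustrative, so that the stratified variables above are left untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# refit the same 1-nearest-neighbour model on a plain (non-stratified) split\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.33)\n",
"knn_plain = KNeighborsRegressor(n_neighbors=1).fit(X_tr, y_tr)\n",
"r2_score(y_te, knn_plain.predict(X_te))"
]
},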
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This does not improve the quality of the model, but we can be sure that the under-represented classes are better handled by this stratified random split."
]
}
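,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that [train_test_split](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) also accepts a `stratify` argument that produces the same kind of stratified split in a single call. The sketch below uses the illustrative names `X_train2`, `X_test2`, `y_train2`, `y_test2` so as not to overwrite the variables used above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# stratify=y keeps the label distribution identical in both sets\n",
"X_train2, X_test2, y_train2, y_test2 = train_test_split(\n",
"    X, y, test_size=0.33, stratify=y\n",
")\n",
"y_train2.value_counts(normalize=True).sort_index()"
]
}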
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}