diff --git a/.vscode/settings.json b/.vscode/settings.json index 696de624..1e3b2f16 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,7 @@ { "python.linting.pylintEnabled": true, "python.linting.enabled": true, - "python.pythonPath": "D:\\anaconda3\\python.exe", + "python.pythonPath": "C:\\Python\\python.exe", "files.associations": { "xstring": "cpp", "deque": "cpp", diff --git a/Sklearn/sklearn-cookbook-zh/4.ipynb b/Sklearn/sklearn-cookbook-zh/4.ipynb index fca47bf5..7a164748 100644 --- a/Sklearn/sklearn-cookbook-zh/4.ipynb +++ b/Sklearn/sklearn-cookbook-zh/4.ipynb @@ -10,7 +10,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3-final" + "version": "3.8.0-final" }, "orig_nbformat": 2, "kernelspec": { @@ -133,35 +133,519 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, + "execution_count": 21, + "metadata": { + "tags": [] + }, "outputs": [ { - "output_type": "error", - "ename": "ModuleNotFoundError", - "evalue": "No module named 'pydot'", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mio\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtree\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 9\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mpydot\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 10\u001b[0m \u001b[0mstr_buffer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[0mtree\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_graphviz\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdt\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mout_file\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mstr_buffer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'pydot'" + "output_type": "stream", + "name": "stdout", + "text": [ + "32 32 20\n" ] } ], "source": [ "from sklearn import datasets \n", - "X, y = datasets.make_classification(1000, 20, n_informative=3)\n", - "from sklearn.tree import DecisionTreeClassifier \n", - "dt = DecisionTreeClassifier() \n", - "dt.fit(X, y) \n", + "from sklearn.tree import DecisionTreeClassifier\n", + "import numpy as np\n", "\n", - "from io import StringIO\n", - "from sklearn import tree\n", - "import pydot\n", - "str_buffer = StringIO()\n", - "tree.export_graphviz(dt,out_file=str_buffer)\n", - "graph = pydot.graph_from_dot_data(str_buffer.ge)\n", - "graph.write(\"myfile.jpg\")" + "X, y = datasets.make_classification(1000, 20, n_informative=3)\n", + "training = np.random.choice([True,False],p=[.75,.25],size=len(y))\n", + "testing = ~training\n", + "\n", + "dt1 = DecisionTreeClassifier() .fit(X[training], y[training])\n", + "pred1 = dt1.predict(X[testing])\n", + "ac1 = (pred1!=y[testing]).sum(axis=0)\n", + "\n", + "# 限定最大深度,防止过拟合\n", + "dt2 = DecisionTreeClassifier(max_depth=5).fit(X[training],y[training])\n", + "pred2 = dt1.predict(X[testing])\n", + "ac2 = (pred2!=y[testing]).sum(axis=0)\n", + "# dt.get_n_leaves()\n", + "\n", + "# 限定划分信息数的方法是为entropy。和最大的叶节点,防止出现过拟合\n", + "dt3 = DecisionTreeClassifier(min_samples_leaf=10,criterion='entropy',max_depth=5).fit(X[training],y[training])\n", + "pred3 = dt3.predict(X[testing])\n", + "ac3 = (pred3!=y[testing]).sum(axis=0)\n", + "\n", + "print(ac1,ac2,ac3)\n", + "# pydot的接口已经变了\n", + "# from io import StringIO\n", + "# from sklearn import tree\n", + "# import pydot\n", + "# str_buffer = StringIO()\n", + "# tree.export_graphviz(dt,out_file=str_buffer)\n", + "# graph = pydot.graph_from_dot_data(str_buffer.getvalue())\n", + "# help(pydot)" + ] + }, + { + "source": [ + "## 4.3 随机森林" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "983\n 0 1 correct\n0 0.592287 0.407713 False\n1 0.393026 0.606974 True\n2 0.635692 0.364308 True\n3 0.011071 0.988929 True\n4 0.000406 0.999594 True\n.. ... ... ...\n995 0.842913 0.157087 True\n996 0.814767 0.185233 True\n997 0.953311 0.046689 True\n998 0.939430 0.060570 True\n999 0.896015 0.103985 True\n\n[1000 rows x 3 columns]\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 4 + }, + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-03-22T11:15:09.412362\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.3.2, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAasAAAFPCAYAAADzxOH5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAATQElEQVR4nO3df5BddX2H8edNQowVihKWTptNTIDUTsCfXUB0hlrUEbCTOFXbpFJErKktqcxQncapMhZrqzK1tmNqpdXWOmoEWmtaUiiDWqv1R4KlaKCRFNBs2o4xIFJtDImf/rE35LrdsDfszd4vuc9rJsOec7+c+8kfmWfO2XPPTVUhSVLLjhn0AJIkTcdYSZKaZ6wkSc0zVpKk5hkrSVLzjJUkqXk9xSrJ+Um2JdmeZN0h1vxCkjuSbE3ykf6OKUkaZpnuc1ZJ5gBfA14IjAObgdVVdUfXmmXAtcB5VXV/kpOr6ptHbmxJ0jDp5czqLGB7Vd1dVXuBDcDKSWteA6yvqvsBDJUkqZ/m9rBmIbCja3scOHvSmp8ESPI5YA7wlqq68ZEOetJJJ9WSJUt6n1SSdNS79dZbv1VVI5P39xKrXswFlgHPA0aBzyR5alV9u3tRkjXAGoDFixezZcuWPr29JOlokOTrU+3v5TLgTmBR1/ZoZ1+3cWBjVT1UVfcw8TuuZZMPVFXXVNVYVY2NjPy/cEqSNKVeYrUZWJZkaZJ5wCpg46Q1f8vEWRVJTmLisuDd/RtTkjTMpo1VVe0D1gI3AXcC11bV1iRXJVnRWXYTsDvJHcCngDdU1e4jNbQkabhMe+v6kTI2Nlb+zkqS+uuhhx5ifHycPXv2DHqURzR//nxGR0c59thjf2h/kluramzy+n7dYCFJasD4+DjHH388S5YsIcmgx5lSVbF7927Gx8dZunRpT/+Pj1uSpKPInj17WLBgQbOhAkjCggULDuvsz1hJ0lGm5VAdcLgzGitJUt/deOONPOUpT+G0007j7W9/+4yP5++sJOkotmTdDX093r1vf/G0a/bv389ll13GzTffzOjoKGeeeSYrVqxg+fLlj/p9PbOSJPXVl770JU477TROOeUU5s2bx6pVq/jEJz4xo2MaK0lSX+3cuZNFiw4++Gh0dJSdOyc/+OjweBlwBvp9ej0IvZzSS9KgeWYlSeqrhQsXsmPHwS/rGB8fZ+HChTM6prGSJPXVmWeeyV133cU999zD3r172bBhAytWrJj+f3wEXgaUJPXV3Llzec973sOLXvQi9u/fz6WXXsrpp58+s2P2aTZJUoMG9XvpCy+8kAsvvLBvx/MyoCSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqS+uvTSSzn55JM544wz+nZMP2clSUezt5zQ5+M9MO2SSy65hLVr13LxxRf37W09s5Ik9dW5557LiSee2NdjGitJUvOMlSSpecZKktQ8YyVJap6xkiT11erVqznnnHPYtm0bo6OjvP/975/xMb11XZKOZj3cat5vH/3oR/t+TM+sJEnNM1aSpOYZK0lS84yVJB1lqmrQI0zrcGc0VpJ0FJk/fz67d+9uOlhVxe7du5k/f37P/493A0rSUWR0dJTx8XF27do16FEe0fz58xkdHe15vbGSpKPIsccey9KlSwc9Rt95GVCS1DxjJUlqXk+xSnJ+km1JtidZN8XrlyTZleS2zp9f6f+okqRhNe3vrJLMAdYDLwTGgc1JNlbVHZOWfqyq1h6BGSVJQ66XM6uzgO1VdXdV7QU2ACuP7FiSJB3US6wWAju6tsc7+yZ7aZLbk1yfZFFfppMkif7duv53wEer6vtJfhX4IHDe5EVJ1gBrABYvXtynt5Y0aEvW3TDoEWbs3vm/NOgRZm4AT1ifLb2cWe0Eus+URjv7HlZVu6vq+53NPwd+eqoDVdU1VTVWVWMjIyOPZl5J0hDqJVabgWVJliaZB6wCNnYvSPLjXZsrgDv7N6IkadhNexmwqvYlWQvcBMwBPlBVW5NcBWypqo3A65KsAPYB9wGXHMGZJUlDpqffWVXVJmDTpH1Xdv38RuCN/R1NkqQJPsFCktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkpo3d9ADaMDecsKgJ5i5tzww6AkkHWGeWUmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqXk+xSnJ+km1JtidZ9wjrXpqkkoz1b0RJ0rCbNlZJ5gDrgQuA5cDqJMunWHc8cDnwxX4PKUkabr2cWZ0FbK+qu6tqL7ABWDnFurcC7wD29HE+SZJ6itVCYEfX9nhn38OSPAtYVFU39HE2SZKAPtxgkeQY4F3Ab/awdk2SLUm27Nq1a6ZvLUkaEr3EaiewqGt7tLPvgOOBM4BPJ7kXeDawcaqbLKrqmqoaq6qxkZGRRz+1JGmo9BKrzcCyJEuTzANWARsPvFhVD1TVSVW1pKqWAF8AVlTVliMysSRp6Ewbq6raB6wFbgLuBK6tqq1Jrkqy4kgPKElST99nVVWbgE2T9l15iLXPm/lYkiQd5BMsJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1LyeYpXk/CTbkmxPsm6K11+b5CtJbkvy2STL+z+qJGlYTRurJHOA9cAFwHJg9RQx+khVPbWqngG8E3hXvweVJA2vXs6szgK2V9XdVbUX2ACs7F5QVd/p2nwCUP0bUZI07Ob2sGYhsKNrexw4e/KiJJcBVwDzgPOmOlCSNcAagMWLFx/urJKkIdW3Gyyqan1VnQr8FvCmQ6y5pqrGqmpsZGSkX28tSTrK9RKrncCiru3Rzr5D2QC8ZAYzSZL0Q3qJ1WZgWZKlSeYBq4CN3QuSLOvafDFwV/9GlCQNu2l/Z1VV+5KsBW4C5gAfqKqtSa4CtlTVRmBtkhcADwH3A688kkNLkoZLLzdYUFWbgE2T9l3Z9fPlfZ5LkqSH+QQLSVLzjJUkqXnGSpLUPGMlSWqesZIkNc9YSZKaZ6wkSc0zVpKk5hkrSVLzjJUkqXnGSpLUPGMlSWqesZIkNc9YSZKaZ6wkSc0zVpKk5hkrSVLzjJUkqXnGSpLUPGMlSWqesZIkNc9YSZKaZ6wkSc0zVpKk5hkrSVLzjJUkqXnGSpLUPGMlSWqesZIkNc9YSZKaZ6wkSc0zVpKk5hkrSVLzjJUkqXnGSpLUPGMlSWpeT7FKcn6SbUm2J1k3xetXJLkjye1Jbkny5P6PKkkaVtPGKskcYD1wAbAcWJ1k+aRl/wqMVdXTgOuBd/Z7UEnS8OrlzOosYHtV3V1Ve4ENwMruBVX1qar6XmfzC8Bof8eUJA2zXmK1ENjRtT3e2Xcorwb+YaoXkqxJsiXJll27dvU+pSRpqPX1BoskFwFjwNVTvV5V11TVWFWNjYyM9POtJUlHsbk9rNkJLOraHu3s+yFJXgD8NvAzVfX9/ownSVJvZ1abgWVJliaZB6wCNnYvSPJM4H3Aiqr6Zv/HlCQNs2ljVVX7gLXATcCdwLVVtTXJVUlWdJZdDRwHXJfktiQbD3E4SZIOWy+XAamqTcCmSfuu7Pr5BX2eS5Kkh/kEC0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqnrGSJDXPWEmSmmesJEnNM1aSpOYZK0lS84yVJKl5xkqS1DxjJUlqXk+xSnJ+km1JtidZN8Xr5yb5cpJ9SV7W/zElScNs2lglmQOsBy4AlgOrkyyftOwbwCXAR/o9oCRJc3tYcxawvaruBkiyAVgJ3HFgQVXd23ntB0dgRknSkOvlMuBCYEfX9nhn32FLsibJliRbdu3a9WgOIUkaQrN6g0VVXVNVY1U1NjIyMptvLUl6DOslVjuBRV3bo519kiTNil5itRlYlmRpknnAKmDjkR1LkqSDpo1VVe0D1gI3AXcC11bV1iRXJVkBkOTMJOPAy4H3Jdl6JIeWJA2XXu4GpKo2AZsm7buy6+fNTFwelCSp73yChSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzTNWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1r6dYJTk/ybYk25Osm+L1xyX5WOf1LyZZ0vdJJUlDa9pYJZkDrAcuAJYDq5Msn7Ts1cD9VXUa8IfAO/o9qCRpePVyZnUWsL2q7q6qvcAGYOWkNSuBD3Z+vh54fpL0b0xJ0jDrJVYLgR1d2+OdfVOuqap9wAPAgn4MKEnS3Nl8syRrgDWdzf9Jsm0231//X+Ak4FuDnmNGfseTeM2c/xaa8eSpdvYSq53Aoq7t0c6+qdaMJ5kLnADsnnygqroGuKaXaTU7kmypqrFBzyENmv8W2tbLZcDNwLIkS5PMA1YBGyet2Qi8svPzy4BPVlX1b0xJ0jCb9syqqvYlWQvcBMwBPlBVW5NcBWypqo3A+4EPJdkO3MdE0CRJ6ot4AjTckqzpXJ6Vhpr/FtpmrCRJzfNxS5Kk5hkrSVLzjJWkoZQJFyW5srO9OMlZg55LUzNWQyjJjyR5c5I/62wvS/Jzg55LmmV/ApwDrO5sP8jEc1DVIGM1nP4C+D4T/1Bh4kPdvzu4caSBOLuqLgP2AFTV/cC8wY6kQzFWw+nUqnon8BBAVX0POCqe0yIdhoc63ypRAElGgB8MdiQdirEaTnuTPJ6D/0hPZeJMSxomfwx8HDg5yduAzwK/N9iRdCh+zmoIJXkh8CYmvp/sH4HnApdU1acHOZc025L8FPB8Jq4s3FJVdw54JB2CsRpSSRYAz2biH+kXquqx/bRp6TAlWTzV/qr6xmzPoukZqyGU5LnAbVX13SQXAc8C/qiqvj7g0aRZk+QrTFwKDzAfWApsq6rTBzqYpuTvrIbTe4HvJXk6cAXwH8BfDXYkaXZV1VOr6mmd/y5j4lvRPz/ouTQ1YzWc9nW+wmUlsL6q1gPHD3gmaaCq6svA2YOeQ1Ob1W8KVjMeTPJG4CLg3CTHAMcOeCZpViW5omvzGCYuh//ngMbRNDyzGk6/yMSt6q+uqv9m4tufrx7sSNKsO77rz+OAG5i42qAGeYOFpKHT+TDwO6rq9YOeRb3xMuAQSfIgnQ8CT34JqKr60VkeSZp1SeZ2vgH9uYOeRb3zzErSUEny5ap6VpL3AguB64DvHni9qv5mYMPpkDyzGmJJTmbi8yWAH4bU0JkP7AbO4+DnrQowVg0yVkMoyQrgD4CfAL4JPBm4E/DDkBoGJ3fuBPwqByN1gJeaGuXdgMPprUw8aulrVbWUiWejfWGwI0mzZg5wXOfP8V0/H/ijBnlmNZweqqrdSY5JckxVfSrJuwc9lDRL/quqrhr0EDo8xmo4fTvJccBngA8n+SZdv2CWjnJ+d9tjkHcDDpEki6vqG0meAPwvE5eBXwGcAHy4qnYPdEBpFiQ5saruG/QcOjzGaogcuGW38/NfV9VLBz2TJPXCGyyGS/flj1MGNoUkHSZjNVzqED9LUtO8DDhEkuxn4kaKAI8HvnfgJXzckqSGGStJUvO8DChJap6xkiQ1z1hJj0FJnpjk1wc9hzRbjJU0AEnmPtJ2D54IGCsNDR+3JM1QkouB1zPxcYDbgTcDHwBOAnYBr+o8OeQvgT3AM4HPJTlx0vZ6YD0wwsSdmq+pqn9P8mPAn3Lws3G/BrwOODXJbcDNVfWG2fi7SoPi3YDSDCQ5Hfg48Jyq+lYnQB8Erq+qDya5FFhRVS/pxOokYGVV7Z9i+xbgtVV1V5Kzgd+vqvOSfAz4fFW9u/N17McBTwL+vqrOmPW/tDQAnllJM3MecF1VfQugqu5Lcg7w853XPwS8s2v9dVW1f/J258HCzwGuSx5+0Mjjut7j4s7x9wMPJHnSEfnbSI0yVtLsmvx0+wPbxwDfrqpnzO440mODN1hIM/NJ4OVJFsDEE72BfwFWdV5/BfDP0x2kqr4D3JPk5Z3jJMnTOy/fwsTvqUgyJ8kJwINMfHGgNBSMlTQDVbUVeBvwT0n+DXgX8BvAq5LcDvwycHmPh3sF8OrOcbYCKzv7Lwd+NslXgFuB5Z2vc/lckq8mubp/fyOpTd5gIUlqnmdWkqTmGStJUvOMlSSpecZKktQ8YyVJap6xkiQ1z1hJkppnrCRJzfs/XbWHzq19vRsAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "# 创建分类数据\n", + "from sklearn import datasets \n", + "X,y = datasets.make_classification(1000)\n", + "\n", + "# 创建分类器对象\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "rf = RandomForestClassifier(max_depth=9)\n", + "rf.fit(X,y)\n", + "print((y==rf.predict(X)).sum())\n", + "\n", + "# 理解每个预测的不确定性但是并不会绘图pd没学过\n", + "probs = rf.predict_proba(X)\n", + "import pandas as pd \n", + "\n", + "# 创建pd表单,用来做统计学的处理\n", + "probs_df = pd.DataFrame(probs,columns=['0','1'])\n", + "probs_df['correct'] = rf.predict(X)==y\n", + "# np_df = probs_df.to_numpy()\n", + "print(probs_df)\n", + "# print(np_df)\n", + "\n", + "import matplotlib.pyplot as plt\n", + "f,ax = plt.subplots(figsize=(7,5))\n", + "\n", + "# 源代码执行不了。自己对dp也不熟悉。所以只能画一个这个了\n", + "probs_df.groupby('correct').mean().plot(kind='bar',ax=ax)\n", + "\n" + ] + }, + { + "source": [ + "## 4.4 调整随机森林模型" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.7295373665480427\n" + ] + }, + { + "output_type": "error", + "ename": "TypeError", + "evalue": "confusion_matrix() missing 1 required positional argument: 'y_pred'", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[0mrf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mRandomForestClassifier\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmax_features\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mmax_feature\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0mrf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtraining\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtraining\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m \u001b[0mconfusion_matrixes\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mmax_feature\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m\u001b[0mconfusion_matrix\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtesting\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 27\u001b[0m \u001b[0mrf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtesting\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mC:\\Python\\lib\\site-packages\\sklearn\\utils\\validation.py\u001b[0m in \u001b[0;36minner_f\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 61\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 62\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[1;33m<=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 63\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 64\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 65\u001b[0m \u001b[1;31m# extra_args > 0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mTypeError\u001b[0m: confusion_matrix() missing 1 required positional argument: 'y_pred'" + ] + } + ], + "source": [ + "# 创建数据集\n", + "from sklearn import datasets\n", + "X,y = datasets.make_classification(n_samples=10000,n_features=20,n_informative=15,flip_y=.5,weights=[.2,.8])\n", + "\n", + "# 创建训练集的bool屏蔽\n", + "import numpy as np \n", + "training = np.random.choice([True,False],p=[.8,.2],size=y.shape)\n", + "testing =~training\n", + "\n", + "# 创建随机森林的分类器\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "rf = RandomForestClassifier()\n", + "rf.fit(X[training],y[training])\n", + "preds = rf.predict(X[testing])\n", + "print((preds == y[testing]).mean())\n", + "# 0.729\n", + "\n", + "# 迭代max_features 选项查看变化 .分类中用到features多少\n", + "from sklearn.metrics import confusion_matrix\n", + "max_feature_parms = ['auto','sqrt','log2',.01,.5,.99]\n", + "confusion_matrixes = {}\n", + "\n", + "for max_feature in max_feature_parms:\n", + " rf = RandomForestClassifier(max_features=max_feature)\n", + " rf.fit(X[training],y[training])\n", + "confusion_matrixes[max_feature] =confusion_matrix(y[testing])\n", + "rf.predict(X[testing]).ravel()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# 可视化,显然已经运行不聊了\n", + "# import pandas as pd \n", + "# confusion_df = pd.DataFrame(confusion_matrixes)\n", + "\n", + "# import itertools\n", + "# from matplotlib import pyplot as plt \n", + "# f,ax = plt.subplots(figsize=(7,5))\n", + "# confusion_df.plot(kind = 'bar',ax=ax)\n", + "\n", + "# ax.legend(loc='best')\n", + "# ax.set_title(\"Guessed vs Correct\")\n", + "# ax.grid()\n", + "\n", + "# ax.set_xticklabels([str((i,j))for i,j in list(iteratools.product(range(2),range(2)))])\n", + "# ax.set_xlabel(\"gussed vs correct\")\n", + "# ax.set_ylabel(\"correct\")" + ] + }, + { + "source": [ + "## 4.5 使用支持向量机对数据分类" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 30 + }, + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-03-22T12:57:02.907255\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.3.2, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAe8klEQVR4nO3df5Ac5X3n8fd3VyvMSsJaSQYJiQgTZNXpfCBhSYayfbEPTiCViIJjQHDlmDg52S64q7jsio3JGc4+58A5x0XinDldTM7hsPgRfivC/HCOkEvxQwIJDDI6A5FKEgICuwKhBUva/d4fM7Oane2enpnume7p/ryqtjTT3dv9TO/o+zz9fZ5+2twdERHJv560CyAiIp2hgC8iUhAK+CIiBaGALyJSEAr4IiIFMSntAtRz3DGT/PipfWkXQ0Ska7w0+N4b7v6BoHWZDvjHT+3jT889Oe1iiIh0jTUbXtgVtk4pHRGRglDAFxEpCAV8EZGCUMAXESkIBXwRkYLI9CgdkbRt7l/CxukrGeodYGBkiNX772fZ8Na0iyXSEgV8kRCb+5dwy4wLOdwzGYChSTO4ZcaFAAr60pWU0hEJsXH6yrFgX3G4ZzIbp69MqUQi8Sjgi4QY6h1oarlI1ingi4QYGBlqarlI1imHLxJi9f77x+XwAfpGD7F6//2hv6NOXskyBXyREJVA3WgAVyevZJ1SOiIhmm2tq5NXsk4tfJEArbTW1ckrWacWvkiAVlrr6uSVrFMLXyRAK631Vjp5pTt1a+e8Ar5IgIGRIYYmzQhcHqbZTl7pTt3cOa+ALxKg1db6suGtmf9PL/HUS/dl/W+vgC9C8CX62sHbW2qtd+vlfjvk8Vx0c+e8Ar4USlAAAgIv0dcO3s5/fuWPm95/t17uVyQVpPNwLoK0ku7LiqYCvpndCKwGXnf3D5eXzQBuBU4GdgIXufuET25mnwP+qPz2v7j7j1svtkjzwgJQnx+KdYleHSB7GGXUeifs646BNV0R5KKCdDOVQTenPurp5s75Zodl/i/gvJplXwd+5u4LgJ+V349TrhSuBj4KLAeuNrPsX/9IroQFoOGeKYHbN3KJXgmQQ5NmgNmEYF8x3DOFzf1Lmi90h9UL0rWftVIZhH2ubk591LNseCtrB29n4MgguDNwZJC1g7d3RSXWVAvf3R81s5NrFq8BPll+/WPgEeBrNducCzzk7oMAZvYQpYpjQ3PFlaKLk25oNtA0cokeFCADmXVFy7ZekG62xd7NqY8o3do5n0QO/wR331d+/SpwQsA2c4HdVe/3lJdNYGbrgHUAH+hXF0PRhAX0zf1LuHNgDQd7poAZ0HxOOCwATRk9yCGb3NIlejOVSLtbtknk3usF6WZb7N2c+sirRCOqu7uZecx9rAfWA5w689hY+5J0tBp4wvLHL0+ez5NTlwe2pJvJCYcFoE8P3QO0Nn4+LECGbduqqHOaVAdpvSC9cfrKplrsui8he5II+K+Z2Rx332dmc4DXA7bZy9G0D8A8SqkfyaA4LcU4gScsZfDYtLNCc+PQeMs5KgC1EoiCAmTP6BHMYMSO/veK07Jt5JxGpVsa/ZtGnaNmW+zdmvrIqyQC/r3A54Bry//eE7DNA8AfV3XUrgCuTODYkrC4LcVWR2Zs7l8SGrhHI8YWRLWc2zkWvLKfOwbWjHX+HuvvsfidbWzvX9TwtMrV5Vs0vH3c7/6qJt0EE89pvXRLs3/TsCCtFnv3a3ZY5gZKLfVZZraH0siba4HbzOz3gF3AReVtlwJfdPffd/dBM/s2sLm8q29VOnAlW+IOpWtlZEYlIFVy87V6GGWU4BZ+Iw8k6cRY8MM2eaz8B3un8uTU5Q2N3Agq3z9O+9i4fgo8OLNZfU7r5d6THB5ZXRlUKqqbZl6q4N8lmhqW6e6XuPscd+9z93nu/iN3f9Pdz3b3Be5+TiWQu/sWd//9qt+90d1PLf/8VdIfRJIRdyhdKzNG1hvp0jd6iLMOPEbf6KHxK9zpH3knMqh2Yo76OMcI/Oy1FV9IRVh9Tlfvv3/COapUhlGt/1Y0O0RTskHTI8s4caf4rRd4woRWJu6sHbydi/ffPWHc82ff/AnX7r0mskXZibHgcY7RcDlqWvm157Te2PDQv51Zy0FaD3vpThr3KOPEHUrXSp43LB3Rw+i4/baSLqiX6kgqtx9nvHmjo3wM59iRgwz3TAkta9g5CvqbVrSa2snrTVV5p4Av4yTRMddscA4LSKPWGzvfHlaBLRre3lBuv5FKodVKcnP/En5lk0ut9+q0Te17wK2HwzaZjx34R7b3L+KmmZeycfrKhv42lfU3zbw0MD3USpDO801VeaaALxOkMZSuzw9x2PsmBKS4c6+EVWB3DqyJ7MhstMO3lUqydt8AuDNl9CCLD24LHIp6uGfyhA7dyn0KUSOClg1vbXocfT26qao7KeBLqgIDX424aYLaCmxz/5LSHbsRxwrLU988c+3YfsOOESWss3ayH+Li/XeXAnuQgAoxqBKoLR8kG6Q1RLM7KeBLqhqZiybpNMHG6StDR75MGT049jr0voAEUk1ROfBm7uBt9Koo6SDdiSvBPM6nnyYFfElVVOu9HWmCesc8aMeyuX/J2OiWsKAbN9UUlQMP7NcIyO2HCfuM3XTna17n00+TAr6kKjSouretRVe39dzTy80z13LTzEvpHz1Irx8ZN0VCtco49lZaoFHplaDW+KxDr/PLYxdGdvBWPmO3y+t8+mlSwJdUhQW+ds4vvnr//aEjVoCxztLh3qmleXEYxW3iLSv9owdbboE2kl6pvat1wt3I7ix4dwc733fKhCuBod4Brj7xG12dAtHQz+Qp4Euq0uj8Wza8tRTwGzDaM4kpI+/wrr2P0Z6j/11KFQGxWqDNpFfCOnnfmHz8uGfvVpZD96dANPQzeQr4kro08sr9owcZ7p3a0LYHe6bQy8i4ZWZw0Fp/Ulaz6rV2K+fv6hO/MSFAdnMKREM/k6eAL7lx6/TfKo1fp4ceRjnrwGNcvP/uwG1/e+gebp5x8bhWe1g+vIfRCXn8EZuE+ShOZ/LnjbR285YC0dDP5CngSy7cOv23xo1HH6V3bCx7UNAPCiaLhrdPeNBK3+ghDltf4DEdm9Cp264WaCOt3TymQLppVFE3UMCXXHhs2lmBs0w+Nu2s0FZ+UDA55dCuCS3KsDtUMeOYkfc4xg+1vQXaSGtXKRCJooAvuRD2kJSoh6fUCmtRho3qGe6ZwrW7r2nqGK2Kau0qBSJRFPAlF8IeklI942arlg1vLT1APaCTN2vpEqVApJ7Y8+Gb2UIz21b187aZ/UHNNp80s7eqtvlm3OOKVDvrwGMTnwzlXlqegE8P3dP0PP8iWRO7he/uO4DFAGbWS+mB5XcFbPoP7r467vFEglTy9I2O0mlWN6RLNO+MREk6pXM28JK770p4vyKRLt5/d2IBPkiW0yWad0YakfQjDtcCG0LWnWVmz5jZ/Wb2L8N2YGbrzGyLmW15+70jCRdPJJ/0yEFpRGIB38wmA78J3B6w+mlgvrufDvw5cHfYftx9vbsvdfelx71PfcoijcjbTVfSHkm28FcCT7v7a7Ur3P1td3+n/HoT0GdmsxI8tkihxX34vBRDkgH/EkLSOWY226w0iNnMlpeP+2aCxxYptNX779coIomUSM7EzKYA/xb4QtWyLwK4+w3AZ4AvmdkR4F1grXvtGDoRaVU3jCKS9CUS8N39IDCzZtkNVa9/APwgiWOJSLAsjyKSbEh6lI6IiGSUAr6ISEEo4IuIFIQCvohIQSjgi4gUhAK+iEhBKOCLiBSEAr6ISEEo4IuIFIQCvohIQSjgi4gUhAK+iEhBKOCLiBSEAr6ISEEo4IuIFIQCvohIQST5EPOdZvZzM9tmZlsC1puZ/ZmZvWhmz5rZGUkdW0REoiXyxKsqn3L3N0LWrQQWlH8+Cvyw/K+IiHRAJ1M6a4C/9pLHgelmNqeDxxcRKbQkA74DD5rZU2a2LmD9XGB31fs95WXjmNk6M9tiZlvefu9IgsUTESm2JFM6H3f3vWZ2PPCQmb3g7o82uxN3Xw+sBzh15rGeYPlERAotsRa+u+8t//s6cBewvGaTvcBJVe/nlZeJiEgHJBLwzWyKmU2rvAZWAM/VbHYv8Dvl0TpnAm+5+74kji8iItGSSumcANxlZpV9/sTdf2pmXwRw9xuATcAq4EVgGPjdhI4tIiINSCTgu/vLwOkBy2+oeu3A5UkcT0REmqc7bUVECkIBX0SkIBTwRUQKQgFfRKQgFPBFRApCAV9EpCAU8EVECkIBX0SkIBTwRUQKQgFfRKQgFPBFRApCAV9EpCAU8EVECkIBX0SkIBTwRUQKIsln2iZu+qzjWL3u3Nj76Xnkei7/u/WcO+9PEyiViEh3ih3wzewk4K8pPfXKgfXufn3NNp8E7gH+qbzoTnf/VtS+B3tP4Nb3fzVuEZl62VzOu2w/Kx/6Mj0Lto9VAIAqAREpjCRa+EeAr7j70+Xn2j5lZg+5+/aa7f7B3VcncLymvTN0CQC3LwVYBWu+ynmXbWDlQ++nZ0H8K4hmqLIRkbTEDvjlB5HvK78+YGa/AOYCtQE/U94ZuuRoBdBBR682SpVNzyOliyGlnESk3RLN4ZvZycAS4ImA1WeZ2TPAK8BX3f35kH2sA9YBzJo9N8niZULQ1cbUgQ1cP9cTvdpQRSIitaz0bPEEdmQ2Ffh74DvufmfNuuOAUXd/x8xWAde7+4KofZ6y6DT/zs2bEilfEU0dqKStShdb6rwWyb81G154yt2XBq1LpIVvZn3AHcDNtcEewN3frnq9ycz+u5nNcvc3kji+BKtNW029bC7Xz3Xgy2OVQKeoshFJXxKjdAz4EfALdw/832xms4HX3N3NbDml8f9vxj22NOdoBQBp9F1UVzbqvBbpvCRa+B8DPgv83My2lZd9A/g1AHe/AfgM8CUzOwK8C6z1pHJJ0hVqK5vazuukqCIRCZdYDr8dlMOXZk0d2AAw1nehzmspmrbn8EWyolOjoBqhqw3JGgV8yb207rmovcFPVxuSNgV8kTYaV9lUXW0kOVJKFYk0SgFfpIPaMlIqoCLRMFgJooAvkgPho6B0z4UcpYAvkkMTOq87SPdcZJcCvogkSvdcZJcCvoi0VbuuNjTzbPMU8EWkK2XtngvIfmWjgC8iuZHmPRe1lU0WO68V8EVEEtCp2WnjVCQK+CIibdCu2WnjjIJSwBcR6SKRo6A2vBD6uwr4IiJdbOIoqOtDt+3pSIlERCR1CvgiIgWRSMA3s/PMbIeZvWhmXw9Yf4yZ3Vpe/4SZnZzEcUVEpHGxA76Z9QJ/AawEFgGXmNmims1+Dxhy91OB7wPXxT2uiIg0J4kW/nLgRXd/2d0PAbcAa2q2WQP8uPz6b4Czyw8/FxGRDkki4M8Fdle931NeFriNux8B3gJmJnBsERFpUOY6bc1snZltMbMtB4YG0y6OiEhuJBHw9wInVb2fV14WuI2ZTQLeD7wZtDN3X+/uS9196bSBGQkUT0REIJmAvxlYYGYfNLPJwFrg3ppt7gU+V379GeDv3N0TOLaIiDQo9p227n7EzK4AHgB6gRvd/Xkz+xawxd3vBX4E3GRmLwKDlCoFERHpoESmVnD3TcCmmmXfrHr9HnBhEscSEZHWZK7TVkRE2kMBX0SkIBTwRUQKQgFfRKQgFPBFRApCAV9EpCAU8EVECkIBX0SkIBTwRUQKQgFfRKQgFPBFRApCAV9EpCAU8EVECkIBX0SkIBTwRUQKQgFfRKQgFPBFRAoi1hOvzOxPgPOBQ8BLwO+6+/6A7XYCB4AR4Ii7L41zXBERaV7cFv5DwIfd/TTg/wFX1tn2U+6+WMFeRCQdsQK+uz/o7kfKbx8H5sUvkoiItEOSOfzPA/eHrHPgQTN7yszW1duJma0zsy1mtuXA0GCCxRMRKbbIHL6ZPQzMDlh1lbvfU97mKuAIcHPIbj7u7nvN7HjgITN7wd0fDdrQ3dcD6wFOWXSaN/AZRESkAZEB393PqbfezC4DVgNnu3tggHb3veV/Xzezu4DlQGDAFxGR9oiV0jGz84A/BH7T3YdDtpliZtMqr4EVwHNxjisiIs2Lm8P/ATCNUppmm5ndAGBmJ5rZpvI2JwD/18yeAZ4E/tbdfxrzuCIi0qRY4/Dd/dSQ5a8Aq8qvXwZOj3McERGJT3faiogUhAK+iEhBKOCLiBSEAr6ISEEo4IuIFESsUToiWbV11yAPPvsq+4cPM72/jxWnzWbJ/BlpF0skVQr4kjtbdw1y1+Y9HB4p3fi9f/gwd23eA6CgL4WmgC+58+Czr44F+4rDI86Dz75a+ICvK59iU8CX3Nk/fLip5UWhKx9Rp63kzvT+vqaWF0W9Kx8pBgV8yZ0Vp82mr9fGLevrNVacFjTLd3HoykeU0smIIudWk/7sld8t6vkMM72/LzC4R135FPm7mTcK+BlQtNxqdQDpn9zLe4dHGC1nGpL47ApQwVacNnvc9wyir3yK9t3MOwX8DCjSqJLaADJ8aGTCNnE+e9wAlefKopUrnyJ9N4tAAT8DipRbDQogQVr97HEC1N1bdvPES0efo5zH1uyS+TOa+ixF+m4WgQJ+BrSaWw2T5VZqo4Gi1c/eaoDaumtwXLCvKHprNunvpqQr7iMOrzGzveWnXW0zs1Uh251nZjvM7EUz+3qcY+ZRkqNKKimNyn/SSit1666JwSwpW3cNct1927ny1me47r7tdY/VSKCIM6Km1SGZ9YYmFrk1qxFP+ZLEsMzvu/vi8s+m2pVm1gv8BbASWARcYmaLEjhubiyZP4MLls0bC0rT+/u4YNm8llqVnR5r3WwFs+K02fSMjx8YcGxf6asY57NX9t9KgKoX1Psn97ZUljxI8rsp6etESmc58GL5UYeY2S3AGmB7B47dNZrNrYbpdM61lZy5mYEf/Z2eHuP8j8xN5PO3OiQzLHUBpY7lrbsGmy5fllNrzUjquynpSyLgX2FmvwNsAb7i7kM16+cCu6ve7wE+GrYzM1sHrAOYNXtuAsUrlk7nXJutYDY+/Qojo+MriJHRZPPkrQSooCGL1e57am9T+0xrOGNeKhlpj8iUjpk9bGbPBfysAX4I/DqwGNgHfC9ugdx9vbsvdfel0wb0RW1Wp3OuzeTMt+4aDByGCennySupizDvHh5tan9pTGOQRv+NdJfIFr67n9PIjszsfwIbA1btBU6qej+vvEzaoNN3mTZzM0+9YJeFUR9L5s/gtsd3R29YR3ULO0g7K7Ykx8zrSiGfYqV0zGyOu+8rv70AeC5gs83AAjP7IKVAvxa4NM5xpb5O5lybqWDqBbusjPron9wbeBXSSMdtbRonSDsrtqQqGd1dm19xc/jfNbPFgAM7gS8AmNmJwF+6+yp3P2JmVwAPAL3Aje7+fMzjSoY0WsGE9S8c29eTmUCy+owTuePJPeP6GXp7jNVnnBj5u1E3lbV7OGNS/Te6uza/YgV8d/9syPJXgFVV7zcBE4ZsSnZ04hI+LP1z/key0zlf74ol6hzVa0l3Ii3Sylw5QXR3bX7pTtsEdHu+s1OX8J3oX6j9WyycM40d+w40dbygK5ZGzlG9FvbXzm//rSdJnV/dXZtfCvgx5SHfmcQlfFilF7Q8Kvg1WoEGBfendw6N+1vEnRunXids7TkKG9q5cM60ho6VhCT6b5K6UpDs0QNQYsrDU4TiXsKHDQe8e8vupocJNjq0MGi7J14ajJyYrZm/Te0xglSvWzJ/BmecPDBhm6d3DnXV0EjdXZtfauHHlId8Z9xL+LBK78mXB6tvqB1bXu/KodGrjUZn3QzS6N/mvqf2Rh6j9hzt2Hdgwjbd2OGpu2vzSQE/pjzkO+NewocF0NpgH7V9vXW1y+NUqNP7+yLTRlt3DUbebBV0jvLQAJD8UkonpjzMJhj3Ej6scjMLXAwQmuKot6/q2ThbrVD7eo2Fc6ZFpo2i0j5h56heuaJmEhVpNwX8mPKS71wyfwZfO38RF51Zuin6tsd3Nxygwiq95aeEn4ONT7/S8L7g6NVCJTgvnDMtcLtq0/v7+Oivz5jwt9mx70Bkv0u9FvlFZ5401vFcOy30itNm01s7HWjVPjXVgaRJKZ0E5CXf2eqIo3rDAYMeKgLhM1DW7qtmYk2gFJx37DvABcvm1Z0KIWw0UNjvVAf5sFRd/+TesdFHQefqgmXzmNxrvDsanM/qxny+5IcCvoyJMzwzrNKrN+1w2H6r93Xlrc8E/u7+4cMsmT8jdMhkvdRKI/0uYf0alTtu652rqNy/8vmSFgV8GdNKh2NQ5yccbaHXS7s0EviignMrHc6N/s6kHhvbpn9yL6vPOHGsIqp3rupVctVlF+k05fBlTLOPBwwaC3/Hk3v4myd2jy2LO5FYVKd4K30oUb9T+VzVLfXDI+Nb7fXOVVg/RG3ZRTpNLXwZ02xrOSitUftwk3oauQO1kekCWulDqfc7jaS2ws7VwjnTxn6/0v9Q+bcbp92QfFHAT0FW595pdi6WuLnooJuUwsrVyfPTSGor6FzVTu3gXqoEunHUluSTAn6HZX3unWaCa1SuOkpWOy8bvZmu9lxdd992TSssmaYcfoflYe6diqBcdW+PETIMfYKsdl62ejOd7rKVrFMLv8PyFBTCUkDVy/on93L4yAi1IxWz3HnZ6jTDeZhmQ/It7iMObwUWlt9OB/a7++KA7XYCB4AR4Ii7L41z3G6Wt6AQlgIKmk8+i/0WYVrpN0hyWuFuO1/SHeI+8eriymsz+x7wVp3NP+Xub8Q5Xh4Uda7xvNyNXE9SDyDJej+PdK9EcvhmZsBFwIYk9pdneZl7RyZKqlWep34eyZakcvifAF5z91+GrHfgQTNz4H+4+/qwHZnZOmAdwKzZ2XnWaZKK0NotmiRb5Xnq55FsiWzhm9nDZvZcwM+aqs0uoX7r/uPufgawErjczP512Ibuvt7dl7r70mkDCorSHZJslTd7x7NIoyJb+O5+Tr31ZjYJ+DTwkTr72Fv+93UzuwtYDjzaXFFFsivJVnlR+3mk/ZLI4Z8DvODue4JWmtkUM5tWeQ2sAJ5L4LgimZFkq1z9PNIuSeTw11KTzjGzE4G/dPdVwAnAXaV+XSYBP3H3nyZwXOlyeRp6mHSrXP080g6xA767Xxaw7BVgVfn1y8DpcY8j+ZK3oYdJDckUaSfdaSupiPOwlaxSq1yyTnPpSCo09FCk8xTwJRUaeijSeQr4kopWZ6QUkdYphy+pUCenSOcp4Etq1Mkp0llK6YiIFIQCvohIQSjgi4gUhAK+iEhBKOCLiBSEuXv0Vikxs38GdjXxK7OAbnqMosrbXipve6m87dVqeee7+weCVmQ64DfLzLZ00wPSVd72UnnbS+Vtr3aUVykdEZGCUMAXESmIvAX80IejZ5TK214qb3upvO2VeHlzlcMXEZFweWvhi4hICAV8EZGC6PqAb2YXmtnzZjZqZktr1l1pZi+a2Q4zOzetMoYxs9PN7DEz+7mZ3Wdmx6VdpihmttjMHjezbWa2xcyWp12meszs1nJZt5nZTjPblnaZopjZfzCzF8rf6++mXZ56zOwaM9tbdY5XpV2mRpjZV8zMzWxW2mWpx8y+bWbPls/tg2Z2YqwduntX/wD/AlgIPAIsrVq+CHgGOAb4IPAS0Jt2eWvKvhn4jfLrzwPfTrtMDZT5QWBl+fUq4JG0y9RE2b8HfDPtckSU8VPAw8Ax5ffHp12miPJeA3w17XI0WeaTgAco3dQ5K+3yRJT1uKrX/xG4Ic7+ur6F7+6/cPcdAavWALe4+6/c/Z+AF4GstUY/BDxafv0Q8NsplqVRDlSuRN4PvJJiWRpmZgZcBGxIuywRvgRc6+6/AnD311MuTx59H/hDSt/lTHP3t6veTiFmmbs+4NcxF9hd9X5PeVmWPE+pYgK4kFLLI+v+APgTM9sN/DfgynSL07BPAK+5+y/TLkiEDwGfMLMnzOzvzWxZ2gVqwBXltMONZjaQdmHqMbM1wF53fybtsjTKzL5T/v/274BvxtlXVzzxysweBoIednqVu9/T6fI0o17ZKaVx/szM/hNwL3Cok2ULE1Hms4Evu/sdZnYR8CPgnE6Wr1aD349LyEjrPuL8TgJmAGcCy4DbzOwUL1/TpyGivD8Evk2p5fltSmmzz3eudBNFlPcbwIrOlqi+qO+vu18FXGVmVwJXAFe3fKwUv0eJMrNHKOUSt5TfXwng7v+1/P4B4Bp3fyy1QtZhZh8C/re7Zy3tNI6ZvQVMd3cvp0necvdMdzab2SRgL/ARd9+TdnnqMbOfAte5+/8pv38JONPd/zndkkUzs5OBje7+4bTLEsTM/hXwM2C4vGgepZTkcnd/NbWCNcjMfg3YFOf85jmlcy+w1syOMbMPAguAJ1Mu0zhmdnz53x7gj4Ab0i1RQ14BfqP8+t8AWU+RQOkK5IWsB/uyuyl13FYaAZPJ8AyPZjan6u0FwHNplSWKu//c3Y9395Pd/WRKad4zshzszWxB1ds1wAtx9tcVKZ16zOwC4M+BDwB/a2bb3P1cd3/ezG4DtgNHgMvdfSTNsga4xMwuL7++E/irNAvToH8PXF9uNb8HrEu5PI1YS0bSOQ24EbjRzJ6jlOL7XJrpnAZ818wWU0rp7AS+kGpp8udaM1sIjFIaVfTFODvLTUpHRETqy3NKR0REqijgi4gUhAK+iEhBKOCLiBSEAr6ISEEo4IuIFIQCvohIQfx/6ecTaVUWnHkAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "# 使用默认的支持向量机\n", + "from sklearn import datasets\n", + "X,y = datasets.make_classification()\n", + "from sklearn.svm import SVC\n", + "\n", + "base_svm = SVC()\n", + "base_svm.fit(X,y)\n", + "\n", + "# 使用线性核函数的支持向量机\n", + "X,y = datasets.make_blobs(n_features=2,centers=2)\n", + "from sklearn.svm import LinearSVC\n", + "\n", + "svm = LinearSVC()\n", + "svm.fit(X,y)\n", + "\n", + "import numpy as np \n", + "import matplotlib.pyplot as plt \n", + "\n", + "plot_step =0.2\n", + "x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1\n", + "y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1\n", + "xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),np.arange(y_min, y_max, plot_step))\n", + "\n", + "Z = svm.predict(np.c_[xx.ravel(), yy.ravel()])\n", + "Z = Z.reshape(xx.shape)\n", + "\n", + "clr = ['r','g','b']\n", + "plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)\n", + "plt.scatter(X[:,0],X[:,1],cmap='r')" + ] + }, + { + "source": [ + "## 4.6 使用多分类来归纳" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.8208092485549133\n20\n" + ] + } + ], + "source": [ + "# 生成数据集\n", + "from sklearn import datasets\n", + "X,y = datasets.make_classification(n_samples=10000,n_classes=3,n_informative=3)\n", + "\n", + "# 分割数据集\n", + "import numpy as np \n", + "training = np.random.choice([True,False],p=[0.9,0.1],size=y.shape)\n", + "testing = ~training\n", + "from sklearn.tree import DecisionTreeClassifier\n", + "dt = DecisionTreeClassifier()\n", + "dt.fit(X[training],y[training])\n", + "re = dt.predict(X[testing])\n", + "# 输出测试集的准确率\n", + "print((re==y[testing]).mean())\n", + "# 0.85\n", + "print(dt.get_depth())" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.7784200385356455\n" + ] + } + ], + "source": [ + "from sklearn.multiclass import OneVsRestClassifier\n", + "from sklearn.linear_model import LogisticRegression\n", + "\n", + "# 使用两个逻辑回归分类器构建了一个多分类分类器\n", + "mlr = OneVsRestClassifier(LogisticRegression(),n_jobs=2)\n", + "mlr.fit(X[training],y[training])\n", + "re = mlr.predict(X[testing])\n", + "print((re==y[testing]).mean())" + ] + }, + { + "source": [ + "## 4.7 将LDA用于分类" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 4.8 使用QDA-非线性LDA" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 4.9 使用随机梯度下降来分类" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.8888888888888888\n" + ] + } + ], + "source": [ + "from sklearn import datasets\n", + "X,y = datasets.make_classification()\n", + "\n", + "\n", + "# 分割数据集\n", + "import numpy as np \n", + "training = np.random.choice([True,False],p=[0.9,0.1],size=y.shape)\n", + "testing = ~training\n", + "\n", + "from sklearn import linear_model\n", + "sgd_clf = linear_model.SGDClassifier()\n", + "\n", + "sgd_clf.fit(X[training],y[training])\n", + "result = sgd_clf.predict(X[testing])\n", + "print((result==y[testing]).mean())" + ] + }, + { + "source": [ + "## 4.10 使用朴素贝叶斯分类" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(1192,)\n" + ] + } + ], + "source": [ + "# 导入数据\n", + "from sklearn import datasets \n", + "\n", + "categories = [\"rec.autos\",\"rec.motorcycles\"]\n", + "new_groups = datasets.fetch_20newsgroups(categories=categories)\n", + "\n", + "# print(\"\\n\".join(new_groups.data[:1]))\n", + "print(new_groups.target.shape)\n", + "X,y = new_groups.data,new_groups.target\n", + "\n", + "# 分割数据集\n", + "import numpy as np \n", + "training = np.random.choice([True,False],p=[0.9,0.1],size=y.shape)\n", + "testing = ~training" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.9495798319327731\n" + ] + } + ], + "source": [ + "# 使用朴素贝叶斯进行分类。\n", + "# 使用onehot编码\n", + "from sklearn.feature_extraction.text import CountVectorizer\n", + "count_vec = CountVectorizer()\n", + "bow = count_vec.fit_transform(new_groups.data)\n", + "\n", + "# 将矩阵转换为密集数组\n", + "bow = np.array(bow.todense())\n", + "\n", + "words = np.array(count_vec.get_feature_names())\n", + "words[bow[0]>0][:5]\n", + "\n", + "# 朴素贝叶斯\n", + "from sklearn import naive_bayes\n", + "clf = naive_bayes.GaussianNB()\n", + "\n", + "mask = np.random.choice([True,False],len(bow))\n", + "clf.fit(bow[training],new_groups.target[training])\n", + "predictions = clf.predict(bow[testing])\n", + "print((predictions == new_groups.target[testing]).mean())" + ] + }, + { + "source": [ + "## 4.11 标签传递算法。半监督学习" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "['setosa' 'versicolor' 'virginica']\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.98" + ] + }, + "metadata": {}, + "execution_count": 80 + } + ], + "source": [ + "from sklearn import datasets\n", + "iris = datasets.load_iris()\n", + "X = iris.data.copy()\n", + "y = iris.target.copy()\n", + "names = iris.target_names.copy()\n", + "print(names)\n", + "\n", + "# bool屏蔽\n", + "y[np.random.choice([True,False],p=[0.3,0.7],size =len(y))]=-1\n", + "# y[:10]\n", + "\n", + "# 标签传递算法\n", + "from sklearn import semi_supervised\n", + "lp = semi_supervised.LabelPropagation()\n", + "lp.fit(X,y)\n", + "\n", + "preds = lp.predict(X)\n", + "print((preds == iris.target).mean())\n", + "\n", + "# 标签扩充算法\n", + "ls = semi_supervised.LabelSpreading()\n", + "ls.fit(X,y)\n", + "print((ls.predict(X)==iris.target).mean())" ] } ]