diff --git a/pytorch/官方教程/08.md b/pytorch/官方教程/08.md index e6867d85..23eb1f08 100644 --- a/pytorch/官方教程/08.md +++ b/pytorch/官方教程/08.md @@ -4,7 +4,6 @@ 作者:Jeremy Howard,[fast.ai](https://www.fast.ai)。 感谢 Rachel Thomas 和 Francisco Ingham。 -我们建议将本教程作为笔记本而不是脚本来运行。 要下载笔记本(`.ipynb`)文件,请单击页面顶部的链接。 PyTorch 提供设计精美的模块和类[`torch.nn`](https://pytorch.org/docs/stable/nn.html),[`torch.optim`](https://pytorch.org/docs/stable/optim.html),[`Dataset`](https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset)和[`DataLoader`](https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader)神经网络。 为了充分利用它们的功能并针对您的问题对其进行自定义,您需要真正了解它们在做什么。 为了建立这种理解,我们将首先在 MNIST 数据集上训练基本神经网络,而无需使用这些模型的任何功能。 我们最初将仅使用最基本的 PyTorch 张量函数。 然后,我们将一次从`torch.nn`,`torch.optim`,`Dataset`或`DataLoader`中逐个添加一个函数,以准确显示每个函数,以及如何使代码更简洁或更有效。 灵活。 diff --git a/pytorch/实战/8.ipynb b/pytorch/实战/8.ipynb new file mode 100644 index 00000000..31462a42 --- /dev/null +++ b/pytorch/实战/8.ipynb @@ -0,0 +1,242 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.8.0 64-bit", + "metadata": { + "interpreter": { + "hash": "38740d3277777e2cd7c6c2cc9d8addf5118fdf3f82b1b39231fd12aeac8aee8b" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# minist数据集\n", + "from pathlib import Path\n", + "import requests\n", + "DATA_PATH = Path(\"data\")\n", + "PATH = DATA_PATH/\"minist\"\n", + "\n", + "PATH.mkdir(parents=True,exist_ok=True)\n", + "\n", + "URL = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", + "\n", + "FILENAME = \"mnist.pkl.gz\"\n", + "\n", + "if not (PATH / FILENAME).exists():\n", + " content = requests.get(URL + FILENAME).content\n", + " (PATH / FILENAME).open(\"wb\").write(content)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle \n", + "import gzip \n", + "with gzip.open((PATH / FILENAME).as_posix(), \"rb\") as f:\n", + " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(50000, 784)\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-03-23T19:21:21.058038\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.3.2, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAN8klEQVR4nO3df6jVdZ7H8ddrbfojxzI39iZOrWOEUdE6i9nSyjYRTj8o7FYMIzQ0JDl/JDSwyIb7xxSLIVu6rBSDDtXYMus0UJHFMNVm5S6BdDMrs21qoxjlphtmmv1a9b1/3K9xp+75nOs53/PD+34+4HDO+b7P93zffPHl99f53o8jQgAmvj/rdQMAuoOwA0kQdiAJwg4kQdiBJE7o5sJsc+of6LCI8FjT29qy277C9lu237F9ezvfBaCz3Op1dtuTJP1B0gJJOyW9JGlRROwozMOWHeiwTmzZ50l6JyLejYgvJf1G0sI2vg9AB7UT9hmS/jjq/c5q2p+wvcT2kO2hNpYFoE0dP0EXEeskrZPYjQd6qZ0t+y5JZ4x6/51qGoA+1E7YX5J0tu3v2j5R0o8kbaynLQB1a3k3PiIO2V4q6SlJkyQ9EBFv1NYZgFq1fOmtpYVxzA50XEd+VAPg+EHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEi0P2Yzjw6RJk4r1U045paPLX7p0acPaSSedVJx39uzZxfqtt95arN9zzz0Na4sWLSrO+/nnnxfrK1euLNbvvPPOYr0X2gq77fckHZB0WNKhiJhbR1MA6lfHlv3SiPiwhu8B0EEcswNJtBv2kPS07ZdtLxnrA7aX2B6yPdTmsgC0od3d+PkRscv2X0h6xvZ/R8Tm0R+IiHWS1kmS7WhzeQBa1NaWPSJ2Vc97JD0maV4dTQGoX8thtz3Z9pSjryX9QNL2uhoDUK92duMHJD1m++j3/HtE/L6WriaYM888s1g/8cQTi/WLL764WJ8/f37D2tSpU4vzXn/99cV6L+3cubNYX7NmTbE+ODjYsHbgwIHivK+++mqx/sILLxTr/ajlsEfEu5L+qsZeAHQQl96AJAg7kARhB5Ig7EAShB1IwhHd+1HbRP0F3Zw5c4r1TZs2Feudvs20Xx05cqRYv/nmm4v1Tz75pOVlDw8PF+sfffRRsf7WW2+1vOxOiwiPNZ0tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX2GkybNq1Y37JlS7E+a9asOtupVbPe9+3bV6xfeumlDWtffvllcd6svz9oF9fZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJhmyuwd69e4v1ZcuWFetXX311sf7KK68U683+pHLJtm3bivUFCxYU6wcPHizWzzvvvIa12267rTgv6sWWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4H72PnDyyScX682GF167dm3D2uLFi4vz3njjjcX6hg0binX0n5bvZ7f9gO09trePmjbN9jO2366eT62zWQD1G89u/K8kXfG1abdLejYizpb0bPUeQB9rGvaI2Czp678HXShpffV6vaRr620LQN1a/W38QEQcHSzrA0kDjT5oe4mkJS0uB0BN2r4RJiKidOItItZJWidxgg7opVYvve22PV2Squc99bUEoBNaDftGSTdVr2+S9Hg97QDolKa78bY3SPq+pNNs75T0c0krJf3W9mJJ70v6YSebnOj279/f1vwff/xxy/PecsstxfrDDz9crDcbYx39o2nYI2JRg9JlNfcCoIP4uSyQBGEHkiDsQBKEHUiCsANJcIvrBDB58uSGtSeeeKI47yWXXFKsX3nllcX6008/Xayj+xiyGUiOsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BHfWWWcV61u3bi3W9+3bV6w/99xzxfrQ0FDD2n333Vect5v/NicSrrMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJcZ09ucHCwWH/wwQeL9SlTprS87OXLlxfrDz30ULE+PDxcrGfFdXYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSILr7Cg6//zzi/XVq1cX65dd1vpgv2vXri3WV6xYUazv2rWr5WUfz1q+zm77Adt7bG8fNe0O27tsb6seV9XZLID6jWc3/leSrhhj+r9ExJzq8bt62wJQt6Zhj4jNkvZ2oRcAHdTOCbqltl+rdvNPbfQh20tsD9lu/MfIAHRcq2H/haSzJM2RNCxpVaMPRsS6iJgbEXNbXBaAGrQU9ojYHRGHI+KIpF9KmldvWwDq1lLYbU8f9XZQ0vZGnwXQH5peZ7e9QdL3JZ0mabekn1fv50gKSe9J+mlENL25mOvsE8/UqVOL9WuuuaZhrdm98vaYl4u/smnTpmJ9wYIFxfpE1eg6+wnjmHHRGJPvb7sjAF3Fz2WBJAg7kARhB5Ig7EAShB1Igltc0TNffPFFsX7CCeWLRYcOHSrWL7/88oa1559/vjjv8Yw/JQ0kR9iBJAg7kARhB5Ig7EAShB1IgrADSTS96w25XXDBBcX6DTfcUKxfeOGFDWvNrqM3s2PHjmJ98+bNbX3/RMOWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BDd79uxifenSpcX6ddddV6yffvrpx9zTeB0+fLhYHx4u//XyI0eO1NnOcY8tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX240Cza9mLFo010O6IZtfRZ86c2UpLtRgaGirWV6xYUaxv3LixznYmvKZbdttn2H7O9g7bb9i+rZo+zfYztt+unk/tfLsAWjWe3fhDkv4+Is6V9DeSbrV9rqTbJT0bEWdLerZ6D6BPNQ17RAxHxNbq9QFJb0qaIWmhpPXVx9ZLurZDPQKowTEds9ueKel7krZIGoiIoz9O/kDSQIN5lkha0kaPAGow7rPxtr8t6RFJP4uI/aNrMTI65JiDNkbEuoiYGxFz2+oUQFvGFXbb39JI0H8dEY9Wk3fbnl7Vp0va05kWAdSh6W68bUu6X9KbEbF6VGmjpJskrayeH+9IhxPAwMCYRzhfOffcc4v1e++9t1g/55xzjrmnumzZsqVYv/vuuxvWHn+8/E+GW1TrNZ5j9r+V9GNJr9veVk1brpGQ/9b2YknvS/phRzoEUIumYY+I/5I05uDuki6rtx0AncLPZYEkCDuQBGEHkiDsQBKEHUiCW1zHadq0aQ1ra9euLc47Z86cYn3WrFmttFSLF198sVhftWpVsf7UU08V65999tkx94TOYMsOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkuc5+0UUXFevLli0r1ufNm9ewNmPGjJZ6qsunn37asLZmzZrivHfddVexfvDgwZZ6Qv9hyw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaS5zj44ONhWvR07duwo1p988sli/dChQ8V66Z7zffv2FedFHmzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0T5A/YZkh6SNCApJK2LiH+1fYekWyT9b/XR5RHxuybfVV4YgLZFxJijLo8n7NMlTY+IrbanSHpZ0rUaGY/9k4i4Z7xNEHag8xqFfTzjsw9LGq5eH7D9pqTe/mkWAMfsmI7Zbc+U9D1JW6pJS22/ZvsB26c2mGeJ7SHbQ+21CqAdTXfjv/qg/W1JL0haERGP2h6Q9KFGjuP/SSO7+jc3+Q5244EOa/mYXZJsf0vSk5KeiojVY9RnSnoyIs5v8j2EHeiwRmFvuhtv25Lul/Tm6KBXJ+6OGpS0vd0mAXTOeM7Gz5f0n5Jel3Skmrxc0iJJczSyG/+epJ9WJ/NK38WWHeiwtnbj60LYgc5reTcewMRA2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLbQzZ/KOn9Ue9Pq6b1o37trV/7kuitVXX29peNCl29n/0bC7eHImJuzxoo6Nfe+rUvid5a1a3e2I0HkiDsQBK9Dvu6Hi+/pF9769e+JHprVVd66+kxO4Du6fWWHUCXEHYgiZ6E3fYVtt+y/Y7t23vRQyO237P9uu1tvR6frhpDb4/t7aOmTbP9jO23q+cxx9jrUW932N5Vrbtttq/qUW9n2H7O9g7bb9i+rZre03VX6Ksr663rx+y2J0n6g6QFknZKeknSoojY0dVGGrD9nqS5EdHzH2DY/jtJn0h66OjQWrb/WdLeiFhZ/Ud5akT8Q5/0doeOcRjvDvXWaJjxn6iH667O4c9b0Yst+zxJ70TEuxHxpaTfSFrYgz76XkRslrT3a5MXSlpfvV6vkX8sXdegt74QEcMRsbV6fUDS0WHGe7ruCn11RS/CPkPSH0e936n+Gu89JD1t+2XbS3rdzBgGRg2z9YGkgV42M4amw3h309eGGe+bddfK8Oft4gTdN82PiL+WdKWkW6vd1b4UI8dg/XTt9BeSztLIGIDDklb1splqmPFHJP0sIvaPrvVy3Y3RV1fWWy/CvkvSGaPef6ea1hciYlf1vEfSYxo57Ognu4+OoFs97+lxP1+JiN0RcTgijkj6pXq47qphxh+R9OuIeLSa3PN1N1Zf3VpvvQj7S5LOtv1d2ydK+pGkjT3o4xtsT65OnMj2ZEk/UP8NRb1R0k3V65skPd7DXv5Evwzj3WiYcfV43fV8+POI6PpD0lUaOSP/P5L+sRc9NOhrlqRXq8cbve5N0gaN7Nb9n0bObSyW9OeSnpX0tqT/kDStj3r7N40M7f2aRoI1vUe9zdfILvprkrZVj6t6ve4KfXVlvfFzWSAJTtABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL/DyJ7caZa7LphAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "from matplotlib import pyplot as plt \n", + "plt.imshow(x_train[0].reshape((28,28)),cmap=\"gray\")\n", + "print(x_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n x_train,y_train,x_valid,y_valid = map(torch.tensor,(x_train,y_train,x_valid,y_valid))\n" + ] + } + ], + "source": [ + "import torch \n", + "# 对列表中的每一个对象应用torch.tensor()方法\n", + "x_train,y_train,x_valid,y_valid = map(torch.tensor,(x_train,y_train,x_valid,y_valid))\n", + "\n", + "n,c = x_train.shape \n", + "# x_train,x_train.shape,y_train.min(),y_train.max()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([64, 784])\n", + "torch.Size([64, 10])\n", + "tensor([-2.0201, -2.3565, -2.5740, -2.7446, -1.8785, -1.9676, -2.3419, -2.5541,\n", + " -2.3145, -2.7065], grad_fn=)\n", + "tensor(0.1562)\n", + "tensor(0.0827, grad_fn=) tensor(1.)\n" + ] + } + ], + "source": [ + "# 从零开始构建神经网络。\n", + "# pytorch中的_尾随操作,表示原地执行,修改原对象,而不是返回一份copy\n", + "import math \n", + "weights = torch.randn(784,10)/math.sqrt(784)\n", + "weights.requires_grad_()\n", + "bias = torch.zeros(10,requires_grad=True)\n", + "\n", + "# 使用log softmax方法,并添加到第二列\n", + "def log_softmax(x):\n", + " return x - x.exp().sum(-1).log().unsqueeze(-1)\n", + "\n", + "def softmax(x):\n", + " return x.exp()/(x.exp().sum(-1))\n", + "\n", + "test_x = torch.tensor([1,2,3,4])\n", + "# print(log_softmax(test_x))\n", + "test_x = torch.tensor([1,2,3,4])\n", + "# print(softmax(test_x))\n", + "\n", + "# 定义了一个简单的一维线性模型\n", + "def model(xb):\n", + " return log_softmax(xb @ weights +bias )\n", + "\n", + "# 使用模型进行预测\n", + "bs = 64 \n", + "xb = x_train[0:bs]\n", + "print(xb.shape)\n", + "preds = model(xb)\n", + "print(preds.shape)\n", + "print(preds[0])\n", + "\n", + "# 定义损失函数\n", + "def loss_func(input,target):\n", + " return -input[range(target.shape[0]),target].mean()\n", + "\n", + "yb = y_train[0:bs]\n", + "\n", + "\n", + "def accuracy(out,yb):\n", + " preds = torch.argmax(out,dim=1)\n", + " return (preds==yb).float().mean()\n", + "\n", + "print(accuracy(preds,yb ))\n", + "\n", + "lr = 0.5\n", + "epochs = 2 \n", + "for epoch in range(epochs):\n", + " for i in range((n - 1) // bs + 1):\n", + " # set_trace()\n", + " start_i = i * bs\n", + " end_i = start_i + bs\n", + " xb = x_train[start_i:end_i]\n", + " yb = y_train[start_i:end_i]\n", + " pred = model(xb)\n", + " loss = loss_func(pred, yb)\n", + "\n", + " loss.backward()\n", + " with torch.no_grad():\n", + " weights -= weights.grad * lr\n", + " bias -= bias.grad * lr\n", + " weights.grad.zero_()\n", + " bias.grad.zero_()\n", + "\n", + "print(loss_func(model(xb),yb),accuracy(model(xb),yb))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 使用torch.nn.functional" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 使用CNN\n", + "from torch import nn \n", + "from torch import optim \n", + "class mnist_cnn(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(1,16,kernel_size=3,stride=2,padding=1)\n", + " self.conv2 = nn.Conv2d(16,16.kernel_size=3,stride=2,padding=1)\n", + " self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)\n", + "\n", + " def forward(self, xb):\n", + " xb = xb.view(-1, 1, 28, 28)\n", + " xb = F.relu(self.conv1(xb))\n", + " xb = F.relu(self.conv2(xb))\n", + " xb = F.relu(self.conv3(xb))\n", + " xb = F.avg_pool2d(xb, 4)\n", + " return xb.view(-1, xb.size(1))\n", + "\n", + "lr = 0.1\n", + "\n", + "model = mnist_cnn()\n", + "opt = optim.SGD(model.parameters(),lr=lr,momentum=0.9)\n", + "fit(epochs,model,loss_func,opt,train_dl,valid_dl)" + ] + } + ] +} \ No newline at end of file