2020-10-19 21:48:57

This commit is contained in:
wizardforcel
2020-10-19 21:48:57 +08:00
parent 74f7d35aeb
commit 045dee5888
20 changed files with 73 additions and 73 deletions

View File

@@ -8,7 +8,7 @@
## MNIST手写数据
```
```py
import torch
import torch.nn as nn
from torch.autograd import Variable
@@ -40,7 +40,7 @@ train_data = torchvision.datasets.MNIST(
同样, 我们除了训练数据, 还给一些测试数据, 测试看看它有没有训练好.
```
```py
test_data = torchvision.datasets.MNIST(root=\\'./mnist/\\', train=False)
# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
@@ -55,7 +55,7 @@ test_y = test_data.test_labels[:2000]
和以前一样, 我们用一个 class 来建立 CNN 模型. 这个 CNN 整体流程是 卷积( Conv2d ) -> 激励函数( ReLU ) -> 池化, 向下采样 ( MaxPooling ) -> 再来一遍 -> 展平多维的卷积成的特征图 -> 接入全连接层 ( Linear ) -> 输出
```
```py
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
@@ -107,7 +107,7 @@ CNN (
下面我们开始训练, 将  y 都用 Variable 包起来, 然后放入 cnn 中计算 output, 最后再计算误差. 下面代码省略了计算精确度 accuracy 的部分, 如果想细看 accuracy 代码的同学, 请去往我的 github 看全部代码.
```
```py
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
@@ -134,7 +134,7 @@ Epoch: 0 | train loss: 0.0078 | test accuracy: 0.98
最后我们再来取10个数据, 看看预测的值到底对不对:
```
```py
test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, \\'prediction number\\')