mirror of
https://github.com/apachecn/ailearning.git
synced 2026-04-24 18:42:25 +08:00
2020-10-19 21:48:57
## MNIST handwritten digits
```py
import torch
import torch.nn as nn
from torch.autograd import Variable  # legacy wrapper; plain tensors track gradients in PyTorch >= 0.4
import torchvision                   # provides the MNIST dataset loaded below
```
Similarly, besides the training data, we also prepare some test data to check whether the network has been trained well.
```py
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# batch training: 50 samples, 1 channel, 28x28 -> shape (50, 1, 28, 28)
```
As before, we build the CNN model with a class. The overall flow of this CNN is: convolution (Conv2d) -> activation function (ReLU) -> pooling / downsampling (MaxPooling) -> one more round of the same -> flatten the multi-dimensional feature maps produced by the convolutions -> feed into a fully connected layer (Linear) -> output.
```py
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
```
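The class body above is cut off by the diff. Following the flow just described (Conv2d -> ReLU -> MaxPooling, twice, then flatten and a Linear layer), a self-contained sketch of such a model might look like this; the specific channel counts and kernel sizes are assumptions chosen to match a 28x28 MNIST input:

```python
import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(           # input shape (1, 28, 28)
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=5, stride=1, padding=2),  # -> (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),      # -> (16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),       # -> (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                  # -> (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  # 10 digit classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32*7*7)
        return self.out(x)
```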
Next we start training: wrap the batch data in Variable, feed it into cnn to compute the output, and finally compute the loss. The code below omits the part that computes the accuracy; if you want a closer look at the accuracy code, please see the full code on my GitHub.
```py
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
```
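The omitted accuracy bookkeeping boils down to comparing the argmax of the logits with the labels. A tiny self-contained sketch (the tensors here are hypothetical stand-ins for the tutorial's `cnn(test_x)` and `test_y`):

```python
import torch

# Hypothetical stand-ins: 3 samples, 2 classes of logits, and true labels.
test_output = torch.tensor([[0.1, 2.0], [3.0, 0.5], [0.2, 0.9]])
test_y = torch.tensor([1, 0, 0])

# Index of the max logit per row is the predicted class.
pred_y = torch.max(test_output, 1)[1]
accuracy = (pred_y == test_y).float().mean().item()
print('accuracy: %.2f' % accuracy)
```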
Finally, let's take 10 samples and see whether the predicted values are actually correct:
```py
test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
```