mirror of
https://github.com/apachecn/ailearning.git
synced 2026-04-25 11:01:15 +08:00
2020-10-19 21:48:57
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
|
||||
## MNIST手写数据
|
||||
|
||||
```
|
||||
```py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
@@ -42,7 +42,7 @@ train_data = torchvision.datasets.MNIST(
|
||||
|
||||
同样, 我们除了训练数据, 还给一些测试数据, 测试看看它有没有训练好.
|
||||
|
||||
```
|
||||
```py
|
||||
test_data = torchvision.datasets.MNIST(root=\\'./mnist/\\', train=False)
|
||||
|
||||
# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
|
||||
@@ -63,7 +63,7 @@ test_y = test_data.test_labels[:2000]
|
||||
4. (inputN, stateN)-> LSTM -> (outputN, stateN 1) ;
|
||||
5. outputN -> Linear -> prediction . 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
|
||||
|
||||
```
|
||||
```py
|
||||
class RNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(RNN, self).__init__()
|
||||
@@ -103,7 +103,7 @@ RNN (
|
||||
|
||||
我们将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入, 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了. 下面的代码省略了计算 accuracy 的部分, 你可以在我的 github 中看到全部代码.
|
||||
|
||||
```
|
||||
```py
|
||||
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all parameters
|
||||
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
|
||||
|
||||
@@ -129,7 +129,7 @@ Epoch: 0 | train loss: 0.1868 | test accuracy: 0.96
|
||||
|
||||
最后我们再来取10个数据, 看看预测的值到底对不对:
|
||||
|
||||
```
|
||||
```py
|
||||
test_output = rnn(test_x[:10].view(-1, 28, 28))
|
||||
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
|
||||
print(pred_y, \\'prediction number\\')
|
||||
|
||||
Reference in New Issue
Block a user