2020-10-19 21:48:57

This commit is contained in:
wizardforcel
2020-10-19 21:48:57 +08:00
parent 74f7d35aeb
commit 045dee5888
20 changed files with 73 additions and 73 deletions

View File

@@ -8,7 +8,7 @@
## MNIST手写数据
```
```py
import torch
from torch import nn
from torch.autograd import Variable
@@ -42,7 +42,7 @@ train_data = torchvision.datasets.MNIST(
同样, 我们除了训练数据, 还给一些测试数据, 测试看看它有没有训练好.
```
```py
test_data = torchvision.datasets.MNIST(root=\\'./mnist/\\', train=False)
# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
@@ -63,7 +63,7 @@ test_y = test_data.test_labels[:2000]
4. (inputN, stateN)-> LSTM -> (outputN, stateN 1) ;
5. outputN -> Linear -> prediction . 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
```
```py
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
@@ -103,7 +103,7 @@ RNN (
我们将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入, 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了. 下面的代码省略了计算 accuracy 的部分, 你可以在我的 github 中看到全部代码.
```
```py
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
@@ -129,7 +129,7 @@ Epoch: 0 | train loss: 0.1868 | test accuracy: 0.96
最后我们再来取10个数据, 看看预测的值到底对不对:
```
```py
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, \\'prediction number\\')