2020-10-19 21:48:57

This commit is contained in:
wizardforcel
2020-10-19 21:48:57 +08:00
parent 74f7d35aeb
commit 045dee5888
20 changed files with 73 additions and 73 deletions

View File

@@ -14,7 +14,7 @@
![](img/22309cd02ee52b3a65e1f0022e8b964e.png)
```
```py
import torch
from torch import nn
from torch.autograd import Variable
@@ -34,7 +34,7 @@ DOWNLOAD_MNIST = False # set to True if haven\'t download the data
这一次的 RNN, 我们对每一个 r_out  都得放到 Linear  中去计算出预测的 output , 所以我们能用一个 for loop 来循环计算. **这点是 Tensorflow 望尘莫及的!** 除了这点, 还有一些动态的过程都可以在这个教程中查看, 看看我们的 PyTorch 和 Tensorflow 到底哪家强.
```
```py
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
@@ -70,7 +70,7 @@ RNN (
其实熟悉 RNN 的朋友应该知道, forward  过程中的对每个时间点求输出还有一招使得计算量比较小的. 不过上面的内容主要是为了呈现 PyTorch 在动态构图上的优势, 所以我用了一个 for loop  来搭建那套输出系统. 下面介绍一个替换方式. 使用 reshape 的方式整批计算.
```
```py
def forward(self, x, h_state):
r_out, h_state = self.rnn(x, h_state)
r_out_reshaped = r_out.view(-1, HIDDEN_SIZE) # to 2D data
@@ -84,7 +84,7 @@ def forward(self, x, h_state):
![](img/f38868821469cadc36810cfd827511d1.png)
```
```py
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all rnn parameters
loss_func = nn.MSELoss()