mirror of
https://github.com/apachecn/ailearning.git
synced 2026-04-24 18:42:25 +08:00
2020-10-19 21:48:57
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||

|
||||
|
||||
```
|
||||
```py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
@@ -34,7 +34,7 @@ DOWNLOAD_MNIST = False # set to True if haven\'t download the data
|
||||
|
||||
这一次的 RNN, 我们对每一个 r_out 都得放到 Linear 中去计算出预测的 output , 所以我们能用一个 for loop 来循环计算. **这点是 Tensorflow 望尘莫及的!** 除了这点, 还有一些动态的过程都可以在这个教程中查看, 看看我们的 PyTorch 和 Tensorflow 到底哪家强.
|
||||
|
||||
```
|
||||
```py
|
||||
class RNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(RNN, self).__init__()
|
||||
@@ -70,7 +70,7 @@ RNN (
|
||||
|
||||
其实熟悉 RNN 的朋友应该知道, forward 过程中的对每个时间点求输出还有一招使得计算量比较小的. 不过上面的内容主要是为了呈现 PyTorch 在动态构图上的优势, 所以我用了一个 for loop 来搭建那套输出系统. 下面介绍一个替换方式. 使用 reshape 的方式整批计算.
|
||||
|
||||
```
|
||||
```py
|
||||
def forward(self, x, h_state):
|
||||
r_out, h_state = self.rnn(x, h_state)
|
||||
r_out_reshaped = r_out.view(-1, HIDDEN_SIZE) # to 2D data
|
||||
@@ -84,7 +84,7 @@ def forward(self, x, h_state):
|
||||
|
||||

|
||||
|
||||
```
|
||||
```py
|
||||
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all rnn parameters
|
||||
loss_func = nn.MSELoss()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user