mirror of
https://github.com/openmlsys/openmlsys-zh.git
synced 2026-04-04 19:28:32 +08:00
fix ch03-2 (#192)
This commit is contained in:
@@ -153,7 +153,7 @@ $$
|
||||
\nabla\boldsymbol{W_2} &= \boldsymbol{X_2}^\top\nabla\boldsymbol{Y} \\
|
||||
\nabla\boldsymbol{X_1} &= \nabla\boldsymbol{X_2}\boldsymbol{W_1}^\top = (\nabla\boldsymbol{Y}\boldsymbol{W_2}^\top)\boldsymbol{W_1}^\top \\
|
||||
\nabla\boldsymbol{W_1} &= \boldsymbol{X_1}^\top\nabla\boldsymbol{X_2} = \boldsymbol{X_1}^\top(\nabla\boldsymbol{Y}\boldsymbol{W_2}^\top) \\
|
||||
\nabla\boldsymbol{Y} &= \nabla\boldsymbol{X_1}\boldsymbol{W}^\top = ((\nabla\boldsymbol{Y}\boldsymbol{W_2}^\top)\boldsymbol{W_1}^\top)\boldsymbol{W}^\top \\
|
||||
\nabla\boldsymbol{X} &= \nabla\boldsymbol{X_1}\boldsymbol{W}^\top = ((\nabla\boldsymbol{Y}\boldsymbol{W_2}^\top)\boldsymbol{W_1}^\top)\boldsymbol{W}^\top \\
|
||||
\nabla\boldsymbol{W} &= \boldsymbol{X}^\top\nabla\boldsymbol{X_1} = \boldsymbol{X}^\top((\nabla\boldsymbol{Y}\boldsymbol{W_2}^\top)\boldsymbol{W_1}^\top)
|
||||
\end{aligned}
|
||||
$$
|
||||
@@ -162,12 +162,12 @@ $$
|
||||
|
||||
根据上述公式我们可以得出循环控制的反向梯度计算过程如下,在下面代码中伪变量的前缀*grad*代表变量梯度变量,*transpose*代表矩阵转置算子。
|
||||
```python
|
||||
grad_Y2 = matmul(grad_Y3, transpose(W2))
|
||||
grad_W2 = matmul(transpose(Y2), grad_Y3)
|
||||
grad_Y1 = matmul(grad_Y2, transpose(W1))
|
||||
grad_W1 = matmul(transpose(Y1), grad_Y2)
|
||||
grad_Y = matmul(grad_Y1, transpose(W))
|
||||
grad_W = matmul(transpose(Y), grad_Y1)
|
||||
grad_X2 = matmul(grad_Y, transpose(W2))
|
||||
grad_W2 = matmul(transpose(X2), grad_Y)
|
||||
grad_X1 = matmul(grad_X2, transpose(W1))
|
||||
grad_W1 = matmul(transpose(X1), grad_X2)
|
||||
grad_X = matmul(grad_X1, transpose(W))
|
||||
grad_W = matmul(transpose(X), grad_X1)
|
||||
```
|
||||
结合公式、代码以及 :numref:`chain`我们可以看出,在反向传播过程中使用到前向传播的中间变量。因此保存网络中间层输出状态和中间变量,尽管占用了部分内存但能够复用计算结果,达到了提高反向传播计算效率的目的。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user