Files
ailearning/docs/pytorch/24.md
2020-10-19 22:31:47 +08:00

66 lines
2.3 KiB
Markdown
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 5.2 GPU 加速运算
在 GPU 训练可以大幅提升运算速度. 而且 Torch 也有一套很好的 GPU 运算体系. 但是要强调的是:
* 你的电脑里有合适的 GPU 显卡(NVIDIA), 且支持 CUDA 模块. [请在NVIDIA官网查询](https://developer.nvidia.com/cuda-gpus)
* 必须安装 GPU 版的 Torch, [点击这里查看如何安装](https://www.pytorchtutorial.com/1-2-install-pytorch/)
## 用 GPU 训练 CNN
这份 GPU 的代码是依据[之前这份CNN](https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py)的代码修改的. 大概修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 CNN 也变成 GPU 能读的形式. 做法就是在后面加上 .cuda() , 很简单.
```py
...
test_data = torchvision.datasets.MNIST(root=\'./mnist/\', train=False)
# !!!!!!!! 修改 test data 形式 !!!!!!!!! #
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1)).type(torch.FloatTensor)[:2000].cuda()/255\. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
```
再来把我们的 CNN 参数也变成 GPU 兼容形式.
```py
class CNN(nn.Module):
...
cnn = CNN()
# !!!!!!!! 转换 cnn 去 CUDA !!!!!!!!! #
cnn.cuda() # Moves all model parameters and buffers to the GPU.
```
然后就是在 train 的时候, 将每次的training data 变成 GPU 形式. .cuda()
```py
for epoch ..:
for step, ...:
# !!!!!!!! 这里有修改 !!!!!!!!! #
b_x = Variable(x).cuda() # Tensor on GPU
b_y = Variable(y).cuda() # Tensor on GPU
...
if step % 50 == 0:
test_output = cnn(test_x)
# !!!!!!!! 这里有修改 !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # 将操作放去 GPU
accuracy = torch.sum(pred_y == test_y) / test_y.size(0)
...
test_output = cnn(test_x[:10])
# !!!!!!!! 这里有修改 !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # 将操作放去 GPU
...
print(test_y[:10], \'real number\')
```
大功告成~
所以这也就是在我 [github 代码](https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/502_GPU.py) 中的每一步的意义啦.
文章来源:[莫烦](https://morvanzhou.github.io/)