mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-11 22:35:35 +08:00
66 lines
2.3 KiB
Markdown
66 lines
2.3 KiB
Markdown
# 5.2 – GPU 加速运算
|
||
|
||
在 GPU 训练可以大幅提升运算速度. 而且 Torch 也有一套很好的 GPU 运算体系. 但是要强调的是:
|
||
|
||
* 你的电脑里有合适的 GPU 显卡(NVIDIA), 且支持 CUDA 模块. [请在NVIDIA官网查询](https://developer.nvidia.com/cuda-gpus)
|
||
* 必须安装 GPU 版的 Torch, [点击这里查看如何安装](https://www.pytorchtutorial.com/1-2-install-pytorch/)
|
||
|
||
## 用 GPU 训练 CNN
|
||
|
||
这份 GPU 的代码是依据[之前这份CNN](https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py)的代码修改的. 大概修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 CNN 也变成 GPU 能读的形式. 做法就是在后面加上 .cuda() , 很简单.
|
||
|
||
```py
|
||
...
|
||
|
||
test_data = torchvision.datasets.MNIST(root=\'./mnist/\', train=False)
|
||
|
||
# !!!!!!!! 修改 test data 形式 !!!!!!!!! #
|
||
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1)).type(torch.FloatTensor)[:2000].cuda()/255\. # Tensor on GPU
|
||
test_y = test_data.test_labels[:2000].cuda()
|
||
```
|
||
|
||
再来把我们的 CNN 参数也变成 GPU 兼容形式.
|
||
|
||
```py
|
||
class CNN(nn.Module):
|
||
...
|
||
|
||
cnn = CNN()
|
||
|
||
# !!!!!!!! 转换 cnn 去 CUDA !!!!!!!!! #
|
||
cnn.cuda() # Moves all model parameters and buffers to the GPU.
|
||
```
|
||
|
||
然后就是在 train 的时候, 将每次的training data 变成 GPU 形式. .cuda()
|
||
|
||
```py
|
||
for epoch ..:
|
||
for step, ...:
|
||
# !!!!!!!! 这里有修改 !!!!!!!!! #
|
||
b_x = Variable(x).cuda() # Tensor on GPU
|
||
b_y = Variable(y).cuda() # Tensor on GPU
|
||
|
||
...
|
||
|
||
if step % 50 == 0:
|
||
test_output = cnn(test_x)
|
||
|
||
# !!!!!!!! 这里有修改 !!!!!!!!! #
|
||
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # 将操作放去 GPU
|
||
|
||
accuracy = torch.sum(pred_y == test_y) / test_y.size(0)
|
||
...
|
||
|
||
test_output = cnn(test_x[:10])
|
||
|
||
# !!!!!!!! 这里有修改 !!!!!!!!! #
|
||
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # 将操作放去 GPU
|
||
...
|
||
print(test_y[:10], \'real number\')
|
||
```
|
||
|
||
大功告成~
|
||
|
||
所以这也就是在我 [github 代码](https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/502_GPU.py) 中的每一步的意义啦.
|
||
|
||
文章来源:[莫烦](https://morvanzhou.github.io/) |