mirror of
https://github.com/apachecn/ailearning.git
synced 2026-04-24 02:23:45 +08:00
2020-10-19 21:48:57
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
|
||||
这份 GPU 的代码是依据[之前这份CNN](https://www.pytorchtutorial.com/goto/https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py)的代码修改的. 大概修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 CNN 也变成 GPU 能读的形式. 做法就是在后面加上 .cuda() , 很简单.
|
||||
|
||||
```
|
||||
```py
|
||||
...
|
||||
|
||||
test_data = torchvision.datasets.MNIST(root=\'./mnist/\', train=False)
|
||||
@@ -21,7 +21,7 @@ test_y = test_data.test_labels[:2000].cuda()
|
||||
|
||||
再来把我们的 CNN 参数也变成 GPU 兼容形式.
|
||||
|
||||
```
|
||||
```py
|
||||
class CNN(nn.Module):
|
||||
...
|
||||
|
||||
@@ -33,7 +33,7 @@ cnn.cuda() # Moves all model parameters and buffers to the GPU.
|
||||
|
||||
然后就是在 train 的时候, 将每次的training data 变成 GPU 形式. .cuda()
|
||||
|
||||
```
|
||||
```py
|
||||
for epoch ..:
|
||||
for step, ...:
|
||||
# !!!!!!!! 这里有修改 !!!!!!!!! #
|
||||
|
||||
Reference in New Issue
Block a user