2020-10-19 21:48:57

This commit is contained in:
wizardforcel
2020-10-19 21:48:57 +08:00
parent 74f7d35aeb
commit 045dee5888
20 changed files with 73 additions and 73 deletions

View File

@@ -9,7 +9,7 @@
这份 GPU 的代码是依据[之前这份CNN](https://www.pytorchtutorial.com/goto/https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py)的代码修改的. 大概修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 CNN 也变成 GPU 能读的形式. 做法就是在后面加上 .cuda() , 很简单.
```
```py
...
test_data = torchvision.datasets.MNIST(root=\'./mnist/\', train=False)
@@ -21,7 +21,7 @@ test_y = test_data.test_labels[:2000].cuda()
再来把我们的 CNN 参数也变成 GPU 兼容形式.
```
```py
class CNN(nn.Module):
...
@@ -33,7 +33,7 @@ cnn.cuda() # Moves all model parameters and buffers to the GPU.
然后就是在 train 的时候, 将每次的training data 变成 GPU 形式. .cuda()
```
```py
for epoch ..:
for step, ...:
# !!!!!!!! 这里有修改 !!!!!!!!! #