mirror of
https://github.com/apachecn/ailearning.git
synced 2026-04-24 02:23:45 +08:00
2020-10-19 21:48:57
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
|
||||
我们快速地建造数据, 搭建网络:
|
||||
|
||||
```
|
||||
```py
|
||||
torch.manual_seed(1) # reproducible
|
||||
|
||||
# 假数据
|
||||
@@ -35,7 +35,7 @@ def save():
|
||||
|
||||
接下来我们有两种途径来保存
|
||||
|
||||
```
|
||||
```py
|
||||
torch.save(net1, \'net.pkl\') # 保存整个网络
|
||||
torch.save(net1.state_dict(), \'net_params.pkl\') # 只保存网络中的参数 (速度快, 占内存少)
|
||||
```
|
||||
@@ -44,7 +44,7 @@ torch.save(net1.state_dict(), \'net_params.pkl\') # 只保存网络中的参
|
||||
|
||||
这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.
|
||||
|
||||
```
|
||||
```py
|
||||
def restore_net():
|
||||
# restore entire net1 to net2
|
||||
net2 = torch.load(\'net.pkl\')
|
||||
@@ -55,7 +55,7 @@ def restore_net():
|
||||
|
||||
这种方式将会提取所有的参数, 然后再放到你的新建网络中.
|
||||
|
||||
```
|
||||
```py
|
||||
def restore_params():
|
||||
# 新建 net3
|
||||
net3 = torch.nn.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user