Files
notes_estom/Tensorflow/TensorFlow2.0/011.md
yinkanglong_lab 68c2dbc3ac apacheCNml&dl
2021-03-20 16:02:39 +08:00

610 lines
25 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 保存和恢复模型
> 原文:[https://tensorflow.google.cn/tutorials/keras/save_and_load](https://tensorflow.google.cn/tutorials/keras/save_and_load)
**Note:** 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的 [官方英文文档](https://tensorflow.google.cn/?hl=en)。如果您有改进此翻译的建议, 请提交 pull request 到 [tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文,请加入 [docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。
模型可以在训练期间和训练完成后进行保存。这意味着模型可以从任意中断中恢复,并避免耗费比较长的时间在训练上。保存也意味着您可以共享您的模型,而其他人可以通过您的模型来重新创建工作。在发布研究模型和技术时,大多数机器学习从业者分享:
* 用于创建模型的代码
* 模型训练的权重 (weight) 和参数 (parameters) 。
共享数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。
注意小心不受信任的代码——Tensorflow 模型是代码。有关详细信息,请参阅 [安全使用 Tensorflow](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)。
### 选项
保存 Tensorflow 的模型有许多方法——具体取决于您使用的 API。本指南使用 [tf.keras](https://tensorflow.google.cn/guide/keras) 一个高级 API 用于在 Tensorflow 中构建和训练模型。有关其他方法的实现,请参阅 TensorFlow [保存和恢复](https://tensorflow.google.cn/guide/saved_model)指南或[保存到 eager](https://tensorflow.google.cn/guide/eager#object-based_saving)。
## 配置
### 安装并导入
安装并导入 Tensorflow 和依赖项:
```py
pip install -q pyyaml h5py # 以 HDF5 格式保存模型所必须
```
```py
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
```
```py
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
```
```py
2.3.0
```
### 获取示例数据集
要演示如何保存和加载权重,您将使用 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/). 要加快运行速度,请使用前 1000 个示例:
```py
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
```
### 定义模型
首先构建一个简单的序列sequential模型
```py
# 定义一个简单的序列模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
return model
# 创建一个基本的模型实例
model = create_model()
# 显示模型的结构
model.summary()
```
```py
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 512) 401920
_________________________________________________________________
dropout (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
```
## 在训练期间保存模型(以 checkpoints 形式保存)
您可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。 [`tf.keras.callbacks.ModelCheckpoint`](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/ModelCheckpoint) 允许在训练的*过程中*和*结束时*回调保存的模型。
### Checkpoint 回调用法
创建一个只在训练期间保存权重的 [`tf.keras.callbacks.ModelCheckpoint`](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/ModelCheckpoint) 回调:
```py
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个保存模型权重的回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# 使用新的回调训练模型
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images,test_labels),
callbacks=[cp_callback]) # 通过回调训练
# 这可能会生成与保存优化程序状态相关的警告。
# 这些警告(以及整个笔记本中的类似警告)
# 是防止过时使用,可以忽略。
```
```py
Epoch 1/10
29/32 [==========================>...] - ETA: 0s - loss: 1.1844 - accuracy: 0.6595
Epoch 00001: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 8ms/step - loss: 1.1300 - accuracy: 0.6770 - val_loss: 0.7189 - val_accuracy: 0.7780
Epoch 2/10
30/32 [===========================>..] - ETA: 0s - loss: 0.4232 - accuracy: 0.8792
Epoch 00002: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.4216 - accuracy: 0.8800 - val_loss: 0.5160 - val_accuracy: 0.8470
Epoch 3/10
29/32 [==========================>...] - ETA: 0s - loss: 0.2964 - accuracy: 0.9149
Epoch 00003: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.2988 - accuracy: 0.9170 - val_loss: 0.4753 - val_accuracy: 0.8560
Epoch 4/10
29/32 [==========================>...] - ETA: 0s - loss: 0.2057 - accuracy: 0.9494
Epoch 00004: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.2086 - accuracy: 0.9500 - val_loss: 0.4375 - val_accuracy: 0.8600
Epoch 5/10
29/32 [==========================>...] - ETA: 0s - loss: 0.1512 - accuracy: 0.9666
Epoch 00005: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.1488 - accuracy: 0.9680 - val_loss: 0.4275 - val_accuracy: 0.8660
Epoch 6/10
30/32 [===========================>..] - ETA: 0s - loss: 0.1130 - accuracy: 0.9823
Epoch 00006: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.1134 - accuracy: 0.9820 - val_loss: 0.4309 - val_accuracy: 0.8630
Epoch 7/10
29/32 [==========================>...] - ETA: 0s - loss: 0.0829 - accuracy: 0.9925
Epoch 00007: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.0838 - accuracy: 0.9920 - val_loss: 0.4079 - val_accuracy: 0.8680
Epoch 8/10
29/32 [==========================>...] - ETA: 0s - loss: 0.0624 - accuracy: 0.9946
Epoch 00008: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.0627 - accuracy: 0.9950 - val_loss: 0.4176 - val_accuracy: 0.8690
Epoch 9/10
29/32 [==========================>...] - ETA: 0s - loss: 0.0520 - accuracy: 0.9946
Epoch 00009: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.0508 - accuracy: 0.9950 - val_loss: 0.4600 - val_accuracy: 0.8450
Epoch 10/10
29/32 [==========================>...] - ETA: 0s - loss: 0.0462 - accuracy: 0.9968
Epoch 00010: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.0459 - accuracy: 0.9970 - val_loss: 0.4378 - val_accuracy: 0.8660
<tensorflow.python.keras.callbacks.History at 0x7fe7b286b710>
```
这将创建一个 TensorFlow checkpoint 文件集合,这些文件在每个 epoch 结束时更新:
```py
ls {checkpoint_dir}
```
```py
checkpoint cp.ckpt.data-00000-of-00001 cp.ckpt.index
```
创建一个新的未经训练的模型。仅恢复模型的权重时,必须具有与原始模型具有相同网络结构的模型。由于模型具有相同的结构,您可以共享权重,尽管它是模型的不同*实例*。 现在重建一个新的未经训练的模型并在测试集上进行评估。未经训练的模型将在机会水平chance levels上执行准确度约为 10
```py
# 创建一个基本模型实例
model = create_model()
# 评估模型
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
```
```py
32/32 - 0s - loss: 2.3734 - accuracy: 0.0990
Untrained model, accuracy: 9.90%
```
然后从 checkpoint 加载权重并重新评估:
```py
# 加载权重
model.load_weights(checkpoint_path)
# 重新评估模型
loss,acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
```
```py
32/32 - 0s - loss: 0.4378 - accuracy: 0.8660
Restored model, accuracy: 86.60%
```
### checkpoint 回调选项
回调提供了几个选项,为 checkpoint 提供唯一名称并调整 checkpoint 频率。
训练一个新模型,每五个 epochs 保存一次唯一命名的 checkpoint
```py
# 在文件名中包含 epoch (使用 `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个回调,每 5 个 epochs 保存模型的权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
period=5)
# 创建一个新的模型实例
model = create_model()
# 使用 `checkpoint_path` 格式保存权重
model.save_weights(checkpoint_path.format(epoch=0))
# 使用新的回调训练模型
model.fit(train_images,
train_labels,
epochs=50,
callbacks=[cp_callback],
validation_data=(test_images,test_labels),
verbose=0)
```
```py
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Epoch 00005: saving model to training_2/cp-0005.ckpt
Epoch 00010: saving model to training_2/cp-0010.ckpt
Epoch 00015: saving model to training_2/cp-0015.ckpt
Epoch 00020: saving model to training_2/cp-0020.ckpt
Epoch 00025: saving model to training_2/cp-0025.ckpt
Epoch 00030: saving model to training_2/cp-0030.ckpt
Epoch 00035: saving model to training_2/cp-0035.ckpt
Epoch 00040: saving model to training_2/cp-0040.ckpt
Epoch 00045: saving model to training_2/cp-0045.ckpt
Epoch 00050: saving model to training_2/cp-0050.ckpt
<tensorflow.python.keras.callbacks.History at 0x7fe8021c76a0>
```
现在查看生成的 checkpoint 并选择最新的 checkpoint
```py
ls {checkpoint_dir}
```
```py
checkpoint cp-0025.ckpt.index
cp-0000.ckpt.data-00000-of-00001 cp-0030.ckpt.data-00000-of-00001
cp-0000.ckpt.index cp-0030.ckpt.index
cp-0005.ckpt.data-00000-of-00001 cp-0035.ckpt.data-00000-of-00001
cp-0005.ckpt.index cp-0035.ckpt.index
cp-0010.ckpt.data-00000-of-00001 cp-0040.ckpt.data-00000-of-00001
cp-0010.ckpt.index cp-0040.ckpt.index
cp-0015.ckpt.data-00000-of-00001 cp-0045.ckpt.data-00000-of-00001
cp-0015.ckpt.index cp-0045.ckpt.index
cp-0020.ckpt.data-00000-of-00001 cp-0050.ckpt.data-00000-of-00001
cp-0020.ckpt.index cp-0050.ckpt.index
cp-0025.ckpt.data-00000-of-00001
```
```py
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
```
```py
'training_2/cp-0050.ckpt'
```
注意: 默认的 tensorflow 格式仅保存最近的 5 个 checkpoint 。
如果要进行测试,请重置模型并加载最新的 checkpoint
```py
# 创建一个新的模型实例
model = create_model()
# 加载以前保存的权重
model.load_weights(latest)
# 重新评估模型
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
```
```py
32/32 - 0s - loss: 0.4836 - accuracy: 0.8750
Restored model, accuracy: 87.50%
```
## 这些文件是什么?
上述代码将权重存储到 [checkpoint](https://tensorflow.google.cn/guide/saved_model#save_and_restore_variables)—— 格式化文件的集合中,这些文件仅包含二进制格式的训练权重。 Checkpoints 包含:
* 一个或多个包含模型权重的分片。
* 索引文件,指示哪些权重存储在哪个分片中。
如果你只在一台机器上训练一个模型,你将有一个带有后缀的碎片: `.data-00000-of-00001`
## 手动保存权重
您将了解如何将权重加载到模型中。使用 [`Model.save_weights`](https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save_weights) 方法手动保存它们同样简单。默认情况下, [`tf.keras`](https://tensorflow.google.cn/api_docs/python/tf/keras) 和 `save_weights` 特别使用 TensorFlow [checkpoints](https://tensorflow.google.cn/guide/keras/checkpoints) 格式 `.ckpt` 扩展名和 ( 保存在 [HDF5](https://js.tensorflow.org/tutorials/import-keras.html) 扩展名为 `.h5` [保存并序列化模型](https://tensorflow.google.cn/guide/keras/save_and_serialize#weights_only_saving_in_savedmodel_format) )
```py
# 保存权重
model.save_weights('./checkpoints/my_checkpoint')
# 创建模型实例
model = create_model()
# 恢复权重
model.load_weights('./checkpoints/my_checkpoint')
# 评估模型
loss,acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
```
```py
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
32/32 - 0s - loss: 0.4836 - accuracy: 0.8750
Restored model, accuracy: 87.50%
```
## 保存整个模型
调用 [`model.save`](https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save) 将保存模型的结构,权重和训练配置保存在单个文件/文件夹中。这可以让您导出模型,以便在不访问原始 Python 代码*的情况下使用它。因为优化器状态optimizer-state已经恢复您可以从中断的位置恢复训练。
整个模型可以以两种不同的文件格式(`SavedModel``HDF5`)进行保存。需要注意的是 TensorFlow 的 `SavedModel` 格式是 TF2.x. 中的默认文件格式。但是,模型仍可以以 `HDF5` 格式保存。下面介绍了以两种文件格式保存整个模型的更多详细信息。
保存完整模型会非常有用——您可以在 TensorFlow.js[Saved Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras))加载它们,然后在 web 浏览器中训练和运行它们,或者使用 TensorFlow Lite 将它们转换为在移动设备上运行([Saved Model](https://tensorflow.google.cn/lite/convert/python_api#converting_a_savedmodel_), [HDF5](https://tensorflow.google.cn/lite/convert/python_api#converting_a_keras_model_)
*自定义对象(例如,子类化模型或层)在保存和加载时需要特别注意。请参阅下面的**保存自定义对象**部分
### SavedModel 格式
SavedModel 格式是序列化模型的另一种方法。以这种格式保存的模型,可以使用 [`tf.keras.models.load_model`](https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model) 还原,并且模型与 TensorFlow Serving 兼容。[SavedModel 指南](https://tensorflow.google.cn/guide/saved_model)详细介绍了如何提供/检查 SavedModel。以下部分说明了保存和还原模型的步骤。
```py
# 创建并训练一个新的模型实例。
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# 将整个模型另存为 SavedModel。
!mkdir -p saved_model
model.save('saved_model/my_model')
```
```py
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1705 - accuracy: 0.6690
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4326 - accuracy: 0.8780
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2910 - accuracy: 0.9190
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2045 - accuracy: 0.9520
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1538 - accuracy: 0.9650
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/my_model/assets
```
SavedModel 格式是一个包含 protobuf 二进制文件和 Tensorflow 检查点checkpoint的目录。检查保存的模型目录
```py
# my_model 文件夹
ls saved_model
# 包含一个 assets 文件夹saved_model.pb和变量文件夹。
ls saved_model/my_model
```
```py
my_model
assets saved_model.pb variables
```
从保存的模型重新加载新的 Keras 模型:
```py
new_model = tf.keras.models.load_model('saved_model/my_model')
# 检查其架构
new_model.summary()
```
```py
Model: "sequential_5"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_10 (Dense) (None, 512) 401920
_________________________________________________________________
dropout_5 (Dropout) (None, 512) 0
_________________________________________________________________
dense_11 (Dense) (None, 10) 5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
```
还原的模型使用与原始模型相同的参数进行编译。 尝试使用加载的模型运行评估和预测:
```py
# 评估还原的模型
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100*acc))
print(new_model.predict(test_images).shape)
```
```py
32/32 - 0s - loss: 0.4630 - accuracy: 0.0890
Restored model, accuracy: 8.90%
(1000, 10)
```
### HDF5 格式
Keras 使用 [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) 标准提供了一种基本的保存格式。
```py
# 创建并训练一个新的模型实例
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# 将整个模型保存为 HDF5 文件。
# '.h5' 扩展名指示应将模型保存到 HDF5。
model.save('my_model.h5')
```
```py
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1465 - accuracy: 0.6560
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4152 - accuracy: 0.8850
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2801 - accuracy: 0.9280
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2108 - accuracy: 0.9480
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1520 - accuracy: 0.9660
```
现在,从该文件重新创建模型:
```py
# 重新创建完全相同的模型,包括其权重和优化程序
new_model = tf.keras.models.load_model('my_model.h5')
# 显示网络结构
new_model.summary()
```
```py
Model: "sequential_6"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_12 (Dense) (None, 512) 401920
_________________________________________________________________
dropout_6 (Dropout) (None, 512) 0
_________________________________________________________________
dense_13 (Dense) (None, 10) 5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
```
检查其准确率accuracy
```py
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100*acc))
```
```py
32/32 - 0s - loss: 0.4639 - accuracy: 0.0840
Restored model, accuracy: 8.40%
```
Keras 通过检查网络结构来保存模型。这项技术可以保存一切:
* 权重值
* 模型的架构
* 模型的训练配置(您传递给编译的内容)
* 优化器及其状态(如果有的话)(这使您可以在中断的地方重新开始训练)
Keras 无法保存 `v1.x` 优化器(来自 [`tf.compat.v1.train`](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/train)),因为它们与检查点不兼容。对于 v1.x 优化器,您需要在加载-失去优化器的状态后,重新编译模型。
### 保存自定义对象
如果使用的是 SavedModel 格式则可以跳过此部分。HDF5 和 SavedModel 之间的主要区别在于HDF5 使用对象配置保存模型结构,而 SavedModel 保存执行图。因此SavedModel 能够保存自定义对象,例如子类化模型和自定义层,而无需原始代码。
要将自定义对象保存到 HDF5必须执行以下操作:
1. 在对象中定义一个 `get_config` 方法,以及可选的 `from_config` 类方法。
* `get_config(self)` 返回重新创建对象所需的参数的 JSON 可序列化字典。
* `from_config(cls, config)` 使用从 get_config 返回的 config 来创建一个新对象。默认情况下,此函数将使用 config 作为初始化 kwargs`return cls(**config)`)。
2. 加载模型时,将对象传递给 `custom_objects` 参数。参数必须是将字符串类名称映射到 Python 类的字典。例如,`tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`
有关自定义对象和 `get_config` 的示例,请参见[从头开始编写层和模型](https://tensorflow.google.cn/guide/keras/custom_layers_and_models)教程。
```py
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
```