# 保存和恢复模型 > 原文:[https://tensorflow.google.cn/tutorials/keras/save_and_load](https://tensorflow.google.cn/tutorials/keras/save_and_load) **Note:** 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的 [官方英文文档](https://tensorflow.google.cn/?hl=en)。如果您有改进此翻译的建议, 请提交 pull request 到 [tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文,请加入 [docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。 模型可以在训练期间和训练完成后进行保存。这意味着模型可以从任意中断中恢复,并避免耗费比较长的时间在训练上。保存也意味着您可以共享您的模型,而其他人可以通过您的模型来重新创建工作。在发布研究模型和技术时,大多数机器学习从业者分享: * 用于创建模型的代码 * 模型训练的权重 (weight) 和参数 (parameters) 。 共享数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。 注意:小心不受信任的代码——Tensorflow 模型是代码。有关详细信息,请参阅 [安全使用 Tensorflow](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)。 ### 选项 保存 Tensorflow 的模型有许多方法——具体取决于您使用的 API。本指南使用 [tf.keras](https://tensorflow.google.cn/guide/keras), 一个高级 API 用于在 Tensorflow 中构建和训练模型。有关其他方法的实现,请参阅 TensorFlow [保存和恢复](https://tensorflow.google.cn/guide/saved_model)指南或[保存到 eager](https://tensorflow.google.cn/guide/eager#object-based_saving)。 ## 配置 ### 安装并导入 安装并导入 Tensorflow 和依赖项: ```py pip install -q pyyaml h5py # 以 HDF5 格式保存模型所必须 ``` ```py WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available. You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command. ``` ```py import os import tensorflow as tf from tensorflow import keras print(tf.version.VERSION) ``` ```py 2.3.0 ``` ### 获取示例数据集 要演示如何保存和加载权重,您将使用 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/). 要加快运行速度,请使用前 1000 个示例: ```py (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_labels = train_labels[:1000] test_labels = test_labels[:1000] train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0 test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 ``` ### 定义模型 首先构建一个简单的序列(sequential)模型: ```py # 定义一个简单的序列模型 def create_model(): model = tf.keras.models.Sequential([ keras.layers.Dense(512, activation='relu', input_shape=(784,)), keras.layers.Dropout(0.2), keras.layers.Dense(10) ]) model.compile(optimizer='adam', loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) return model # 创建一个基本的模型实例 model = create_model() # 显示模型的结构 model.summary() ``` ```py Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 512) 401920 _________________________________________________________________ dropout (Dropout) (None, 512) 0 _________________________________________________________________ dense_1 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________ ``` ## 在训练期间保存模型(以 checkpoints 形式保存) 您可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。 [`tf.keras.callbacks.ModelCheckpoint`](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/ModelCheckpoint) 允许在训练的*过程中*和*结束时*回调保存的模型。 ### Checkpoint 回调用法 创建一个只在训练期间保存权重的 [`tf.keras.callbacks.ModelCheckpoint`](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/ModelCheckpoint) 回调: ```py checkpoint_path = "training_1/cp.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) # 创建一个保存模型权重的回调 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1) # 使用新的回调训练模型 model.fit(train_images, train_labels, epochs=10, validation_data=(test_images,test_labels), callbacks=[cp_callback]) # 通过回调训练 # 这可能会生成与保存优化程序状态相关的警告。 # 这些警告(以及整个笔记本中的类似警告) # 是防止过时使用,可以忽略。 ``` ```py Epoch 1/10 29/32 [==========================>...] - ETA: 0s - loss: 1.1844 - accuracy: 0.6595 Epoch 00001: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 8ms/step - loss: 1.1300 - accuracy: 0.6770 - val_loss: 0.7189 - val_accuracy: 0.7780 Epoch 2/10 30/32 [===========================>..] - ETA: 0s - loss: 0.4232 - accuracy: 0.8792 Epoch 00002: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.4216 - accuracy: 0.8800 - val_loss: 0.5160 - val_accuracy: 0.8470 Epoch 3/10 29/32 [==========================>...] - ETA: 0s - loss: 0.2964 - accuracy: 0.9149 Epoch 00003: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.2988 - accuracy: 0.9170 - val_loss: 0.4753 - val_accuracy: 0.8560 Epoch 4/10 29/32 [==========================>...] - ETA: 0s - loss: 0.2057 - accuracy: 0.9494 Epoch 00004: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.2086 - accuracy: 0.9500 - val_loss: 0.4375 - val_accuracy: 0.8600 Epoch 5/10 29/32 [==========================>...] - ETA: 0s - loss: 0.1512 - accuracy: 0.9666 Epoch 00005: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.1488 - accuracy: 0.9680 - val_loss: 0.4275 - val_accuracy: 0.8660 Epoch 6/10 30/32 [===========================>..] - ETA: 0s - loss: 0.1130 - accuracy: 0.9823 Epoch 00006: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.1134 - accuracy: 0.9820 - val_loss: 0.4309 - val_accuracy: 0.8630 Epoch 7/10 29/32 [==========================>...] - ETA: 0s - loss: 0.0829 - accuracy: 0.9925 Epoch 00007: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.0838 - accuracy: 0.9920 - val_loss: 0.4079 - val_accuracy: 0.8680 Epoch 8/10 29/32 [==========================>...] - ETA: 0s - loss: 0.0624 - accuracy: 0.9946 Epoch 00008: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.0627 - accuracy: 0.9950 - val_loss: 0.4176 - val_accuracy: 0.8690 Epoch 9/10 29/32 [==========================>...] - ETA: 0s - loss: 0.0520 - accuracy: 0.9946 Epoch 00009: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.0508 - accuracy: 0.9950 - val_loss: 0.4600 - val_accuracy: 0.8450 Epoch 10/10 29/32 [==========================>...] - ETA: 0s - loss: 0.0462 - accuracy: 0.9968 Epoch 00010: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 4ms/step - loss: 0.0459 - accuracy: 0.9970 - val_loss: 0.4378 - val_accuracy: 0.8660 ``` 这将创建一个 TensorFlow checkpoint 文件集合,这些文件在每个 epoch 结束时更新: ```py ls {checkpoint_dir} ``` ```py checkpoint cp.ckpt.data-00000-of-00001 cp.ckpt.index ``` 创建一个新的未经训练的模型。仅恢复模型的权重时,必须具有与原始模型具有相同网络结构的模型。由于模型具有相同的结构,您可以共享权重,尽管它是模型的不同*实例*。 现在重建一个新的未经训练的模型,并在测试集上进行评估。未经训练的模型将在机会水平(chance levels)上执行(准确度约为 10%): ```py # 创建一个基本模型实例 model = create_model() # 评估模型 loss, acc = model.evaluate(test_images, test_labels, verbose=2) print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) ``` ```py 32/32 - 0s - loss: 2.3734 - accuracy: 0.0990 Untrained model, accuracy: 9.90% ``` 然后从 checkpoint 加载权重并重新评估: ```py # 加载权重 model.load_weights(checkpoint_path) # 重新评估模型 loss,acc = model.evaluate(test_images, test_labels, verbose=2) print("Restored model, accuracy: {:5.2f}%".format(100*acc)) ``` ```py 32/32 - 0s - loss: 0.4378 - accuracy: 0.8660 Restored model, accuracy: 86.60% ``` ### checkpoint 回调选项 回调提供了几个选项,为 checkpoint 提供唯一名称并调整 checkpoint 频率。 训练一个新模型,每五个 epochs 保存一次唯一命名的 checkpoint : ```py # 在文件名中包含 epoch (使用 `str.format`) checkpoint_path = "training_2/cp-{epoch:04d}.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) # 创建一个回调,每 5 个 epochs 保存模型的权重 cp_callback = tf.keras.callbacks.ModelCheckpoint( filepath=checkpoint_path, verbose=1, save_weights_only=True, period=5) # 创建一个新的模型实例 model = create_model() # 使用 `checkpoint_path` 格式保存权重 model.save_weights(checkpoint_path.format(epoch=0)) # 使用新的回调训练模型 model.fit(train_images, train_labels, epochs=50, callbacks=[cp_callback], validation_data=(test_images,test_labels), verbose=0) ``` ```py WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen. WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. Epoch 00005: saving model to training_2/cp-0005.ckpt Epoch 00010: saving model to training_2/cp-0010.ckpt Epoch 00015: saving model to training_2/cp-0015.ckpt Epoch 00020: saving model to training_2/cp-0020.ckpt Epoch 00025: saving model to training_2/cp-0025.ckpt Epoch 00030: saving model to training_2/cp-0030.ckpt Epoch 00035: saving model to training_2/cp-0035.ckpt Epoch 00040: saving model to training_2/cp-0040.ckpt Epoch 00045: saving model to training_2/cp-0045.ckpt Epoch 00050: saving model to training_2/cp-0050.ckpt ``` 现在查看生成的 checkpoint 并选择最新的 checkpoint : ```py ls {checkpoint_dir} ``` ```py checkpoint cp-0025.ckpt.index cp-0000.ckpt.data-00000-of-00001 cp-0030.ckpt.data-00000-of-00001 cp-0000.ckpt.index cp-0030.ckpt.index cp-0005.ckpt.data-00000-of-00001 cp-0035.ckpt.data-00000-of-00001 cp-0005.ckpt.index cp-0035.ckpt.index cp-0010.ckpt.data-00000-of-00001 cp-0040.ckpt.data-00000-of-00001 cp-0010.ckpt.index cp-0040.ckpt.index cp-0015.ckpt.data-00000-of-00001 cp-0045.ckpt.data-00000-of-00001 cp-0015.ckpt.index cp-0045.ckpt.index cp-0020.ckpt.data-00000-of-00001 cp-0050.ckpt.data-00000-of-00001 cp-0020.ckpt.index cp-0050.ckpt.index cp-0025.ckpt.data-00000-of-00001 ``` ```py latest = tf.train.latest_checkpoint(checkpoint_dir) latest ``` ```py 'training_2/cp-0050.ckpt' ``` 注意: 默认的 tensorflow 格式仅保存最近的 5 个 checkpoint 。 如果要进行测试,请重置模型并加载最新的 checkpoint : ```py # 创建一个新的模型实例 model = create_model() # 加载以前保存的权重 model.load_weights(latest) # 重新评估模型 loss, acc = model.evaluate(test_images, test_labels, verbose=2) print("Restored model, accuracy: {:5.2f}%".format(100*acc)) ``` ```py 32/32 - 0s - loss: 0.4836 - accuracy: 0.8750 Restored model, accuracy: 87.50% ``` ## 这些文件是什么? 上述代码将权重存储到 [checkpoint](https://tensorflow.google.cn/guide/saved_model#save_and_restore_variables)—— 格式化文件的集合中,这些文件仅包含二进制格式的训练权重。 Checkpoints 包含: * 一个或多个包含模型权重的分片。 * 索引文件,指示哪些权重存储在哪个分片中。 如果你只在一台机器上训练一个模型,你将有一个带有后缀的碎片: `.data-00000-of-00001` ## 手动保存权重 您将了解如何将权重加载到模型中。使用 [`Model.save_weights`](https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save_weights) 方法手动保存它们同样简单。默认情况下, [`tf.keras`](https://tensorflow.google.cn/api_docs/python/tf/keras) 和 `save_weights` 特别使用 TensorFlow [checkpoints](https://tensorflow.google.cn/guide/keras/checkpoints) 格式 `.ckpt` 扩展名和 ( 保存在 [HDF5](https://js.tensorflow.org/tutorials/import-keras.html) 扩展名为 `.h5` [保存并序列化模型](https://tensorflow.google.cn/guide/keras/save_and_serialize#weights_only_saving_in_savedmodel_format) ): ```py # 保存权重 model.save_weights('./checkpoints/my_checkpoint') # 创建模型实例 model = create_model() # 恢复权重 model.load_weights('./checkpoints/my_checkpoint') # 评估模型 loss,acc = model.evaluate(test_images, test_labels, verbose=2) print("Restored model, accuracy: {:5.2f}%".format(100*acc)) ``` ```py WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. 32/32 - 0s - loss: 0.4836 - accuracy: 0.8750 Restored model, accuracy: 87.50% ``` ## 保存整个模型 调用 [`model.save`](https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save) 将保存模型的结构,权重和训练配置保存在单个文件/文件夹中。这可以让您导出模型,以便在不访问原始 Python 代码*的情况下使用它。因为优化器状态(optimizer-state)已经恢复,您可以从中断的位置恢复训练。 整个模型可以以两种不同的文件格式(`SavedModel` 和 `HDF5`)进行保存。需要注意的是 TensorFlow 的 `SavedModel` 格式是 TF2.x. 中的默认文件格式。但是,模型仍可以以 `HDF5` 格式保存。下面介绍了以两种文件格式保存整个模型的更多详细信息。 保存完整模型会非常有用——您可以在 TensorFlow.js([Saved Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras))加载它们,然后在 web 浏览器中训练和运行它们,或者使用 TensorFlow Lite 将它们转换为在移动设备上运行([Saved Model](https://tensorflow.google.cn/lite/convert/python_api#converting_a_savedmodel_), [HDF5](https://tensorflow.google.cn/lite/convert/python_api#converting_a_keras_model_)) *自定义对象(例如,子类化模型或层)在保存和加载时需要特别注意。请参阅下面的**保存自定义对象**部分 ### SavedModel 格式 SavedModel 格式是序列化模型的另一种方法。以这种格式保存的模型,可以使用 [`tf.keras.models.load_model`](https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model) 还原,并且模型与 TensorFlow Serving 兼容。[SavedModel 指南](https://tensorflow.google.cn/guide/saved_model)详细介绍了如何提供/检查 SavedModel。以下部分说明了保存和还原模型的步骤。 ```py # 创建并训练一个新的模型实例。 model = create_model() model.fit(train_images, train_labels, epochs=5) # 将整个模型另存为 SavedModel。 !mkdir -p saved_model model.save('saved_model/my_model') ``` ```py Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1705 - accuracy: 0.6690 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4326 - accuracy: 0.8780 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2910 - accuracy: 0.9190 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2045 - accuracy: 0.9520 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1538 - accuracy: 0.9650 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: saved_model/my_model/assets ``` SavedModel 格式是一个包含 protobuf 二进制文件和 Tensorflow 检查点(checkpoint)的目录。检查保存的模型目录: ```py # my_model 文件夹 ls saved_model # 包含一个 assets 文件夹,saved_model.pb,和变量文件夹。 ls saved_model/my_model ``` ```py my_model assets saved_model.pb variables ``` 从保存的模型重新加载新的 Keras 模型: ```py new_model = tf.keras.models.load_model('saved_model/my_model') # 检查其架构 new_model.summary() ``` ```py Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_10 (Dense) (None, 512) 401920 _________________________________________________________________ dropout_5 (Dropout) (None, 512) 0 _________________________________________________________________ dense_11 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________ ``` 还原的模型使用与原始模型相同的参数进行编译。 尝试使用加载的模型运行评估和预测: ```py # 评估还原的模型 loss, acc = new_model.evaluate(test_images, test_labels, verbose=2) print('Restored model, accuracy: {:5.2f}%'.format(100*acc)) print(new_model.predict(test_images).shape) ``` ```py 32/32 - 0s - loss: 0.4630 - accuracy: 0.0890 Restored model, accuracy: 8.90% (1000, 10) ``` ### HDF5 格式 Keras 使用 [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) 标准提供了一种基本的保存格式。 ```py # 创建并训练一个新的模型实例 model = create_model() model.fit(train_images, train_labels, epochs=5) # 将整个模型保存为 HDF5 文件。 # '.h5' 扩展名指示应将模型保存到 HDF5。 model.save('my_model.h5') ``` ```py Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1465 - accuracy: 0.6560 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4152 - accuracy: 0.8850 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2801 - accuracy: 0.9280 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2108 - accuracy: 0.9480 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1520 - accuracy: 0.9660 ``` 现在,从该文件重新创建模型: ```py # 重新创建完全相同的模型,包括其权重和优化程序 new_model = tf.keras.models.load_model('my_model.h5') # 显示网络结构 new_model.summary() ``` ```py Model: "sequential_6" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_12 (Dense) (None, 512) 401920 _________________________________________________________________ dropout_6 (Dropout) (None, 512) 0 _________________________________________________________________ dense_13 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________ ``` 检查其准确率(accuracy): ```py loss, acc = new_model.evaluate(test_images, test_labels, verbose=2) print('Restored model, accuracy: {:5.2f}%'.format(100*acc)) ``` ```py 32/32 - 0s - loss: 0.4639 - accuracy: 0.0840 Restored model, accuracy: 8.40% ``` Keras 通过检查网络结构来保存模型。这项技术可以保存一切: * 权重值 * 模型的架构 * 模型的训练配置(您传递给编译的内容) * 优化器及其状态(如果有的话)(这使您可以在中断的地方重新开始训练) Keras 无法保存 `v1.x` 优化器(来自 [`tf.compat.v1.train`](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/train)),因为它们与检查点不兼容。对于 v1.x 优化器,您需要在加载-失去优化器的状态后,重新编译模型。 ### 保存自定义对象 如果使用的是 SavedModel 格式,则可以跳过此部分。HDF5 和 SavedModel 之间的主要区别在于,HDF5 使用对象配置保存模型结构,而 SavedModel 保存执行图。因此,SavedModel 能够保存自定义对象,例如子类化模型和自定义层,而无需原始代码。 要将自定义对象保存到 HDF5,必须执行以下操作: 1. 在对象中定义一个 `get_config` 方法,以及可选的 `from_config` 类方法。 * `get_config(self)` 返回重新创建对象所需的参数的 JSON 可序列化字典。 * `from_config(cls, config)` 使用从 get_config 返回的 config 来创建一个新对象。默认情况下,此函数将使用 config 作为初始化 kwargs(`return cls(**config)`)。 2. 加载模型时,将对象传递给 `custom_objects` 参数。参数必须是将字符串类名称映射到 Python 类的字典。例如,`tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})` 有关自定义对象和 `get_config` 的示例,请参见[从头开始编写层和模型](https://tensorflow.google.cn/guide/keras/custom_layers_and_models)教程。 ```py # MIT License # # Copyright (c) 2017 François Chollet # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Software"), # to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, # and/or sell copies of the Software, and to permit persons to whom the # Software is furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. ```