# Save and load a model using a distribution strategy

> Original: [https://tensorflow.google.cn/tutorials/distribute/save_and_load](https://tensorflow.google.cn/tutorials/distribute/save_and_load)

## Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how to use the SavedModel APIs when using [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy). To learn about SavedModel and serialization in general, see the [saved model guide](https://tensorflow.google.cn/guide/saved_model) and the [Keras model serialization guide](https://tensorflow.google.cn/guide/keras/save_and_serialize). Let's start with a simple example.

Import dependencies:

```py
import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()
```

Prepare the data and model using [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy):

```py
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
    datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
    mnist_train, mnist_test = datasets['train'], datasets['test']

    BUFFER_SIZE = 10000

    # The global batch size scales with the number of replicas in sync.
    BATCH_SIZE_PER_REPLICA = 64
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

    def scale(image, label):
        # Convert pixel values to floats in the range [0, 1].
        image = tf.cast(image, tf.float32)
        image /= 255

        return image, label

    train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

    return train_dataset, eval_dataset

def get_model():
    # Create and compile the model inside the strategy scope.
    with mirrored_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])

        model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      optimizer=tf.keras.optimizers.Adam(),
                      metrics=['accuracy'])
        return model
```

```py
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
```

Train the model:

```py
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
```

```py
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 4s 5ms/step - loss: 0.1971 - accuracy: 0.9421
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0662 - accuracy: 0.9801
```

## Save and load the model

Now that you have a simple model to work with, let's look at the APIs for saving and loading it. There are two sets of APIs available:

* High-level Keras `model.save` and [`tf.keras.models.load_model`](https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model)
* Low-level [`tf.saved_model.save`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/save) and [`tf.saved_model.load`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load)

### Keras API

Here is an example of saving and loading a model with the Keras API:

```py
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called outside the strategy scope
```

```py
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
```

Restore the model without [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy):

```py
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
```

```py
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0480 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0334 - accuracy: 0.0989
```

After restoring the model, you can continue training it without calling `compile()` again, because it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format. For more information, see the [guide to the `saved_model` format](https://tensorflow.google.cn/guide/saved_model).

Now load the model and train it using a [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy):

```py
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
    # Loading inside the scope places the restored model under this strategy.
    restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
    restored_keras_model_ds.fit(train_dataset, epochs=2)
```

```py
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0481 - accuracy: 0.0989
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0329 - accuracy: 0.0990
```

As you can see, loading works as expected with [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy). The strategy used here does not have to be the same one used before saving.

### [`tf.saved_model`](https://tensorflow.google.cn/api_docs/python/tf/saved_model) API

Now let's look at the lower-level API. Saving the model is similar to the Keras API:

```py
model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
```

```py
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
```

Loading can be done with [`tf.saved_model.load()`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load). However, because this API is lower level (and therefore covers a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example:

```py
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
```

The loaded object may contain multiple functions, each associated with a key. `"serving_default"` is the default key for the inference function of a saved Keras model. To run inference with this function:

```py
# Drop the labels so that each batch contains only images.
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
    print(inference_func(batch))
```

```py
{'dense_3': }
```

You can also load the model and run inference in a distributed manner:

```py
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

    dist_predict_dataset = another_strategy.experimental_distribute_dataset(
        predict_dataset)

    # Calling the function in a distributed manner
    for batch in dist_predict_dataset:
        another_strategy.run(inference_func, args=(batch,))
```

```py
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently.
We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
```

Calling the restored function is just a forward pass (prediction) on the saved model. What if you want to continue training the loaded function, or embed it into a larger model?

A common practice is to wrap the loaded object in a Keras layer. Luckily, [TF Hub](https://tensorflow.google.cn/hub) provides [hub.KerasLayer](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, as shown here:

```py
import tensorflow_hub as hub

def build_model(loaded):
    x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
    # Wrap what's loaded to a KerasLayer
    keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
    model = tf.keras.Model(x, keras_layer)
    return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    model = build_model(loaded)

    # The wrapped model is a new Keras model, so it must be compiled again.
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    model.fit(train_dataset, epochs=2)
```

```py
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.2059 - accuracy: 0.9393
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0681 - accuracy: 0.9799
```

As you can see, [`hub.KerasLayer`](https://tensorflow.google.cn/hub/api_docs/python/hub/KerasLayer) wraps the result loaded back from [`tf.saved_model.load()`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load) into a Keras layer that can be used to build another model. This is very useful for transfer learning.

### Which API should I use?
For saving, if you are working with a Keras model, the recommendation is always to use Keras's `model.save()` API. If what you are saving is not a Keras model, the lower-level API is your only choice.

For loading, which API you use depends on what you want to get back from the loading API. If you cannot (or do not want to) get a Keras model, use [`tf.saved_model.load()`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load). Otherwise, use [`tf.keras.models.load_model()`](https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with `model.save`, and load a non-Keras model with the low-level API [`tf.saved_model.load`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load).

```py
model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
    loaded = tf.saved_model.load(keras_model_path)
```

```py
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
```

### Caveats

A special case is when your Keras model does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (`Sequential([Dense(3), ...])`). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, check whether `model.inputs` is `None`. If it is not `None`, you are all good. Input shapes are defined automatically when the model is used in `.fit`, `.evaluate`, or `.predict`, or when the model is called (`model(inputs)`).

Here is an example:

```py
class SubclassedModel(tf.keras.Model):

    output_name = 'output_layer'

    def __init__(self):
        super(SubclassedModel, self).__init__()
        self._dense_layer = tf.keras.layers.Dense(
            5, dtype=tf.dtypes.float32, name=self.output_name)

    def call(self, inputs):
        return self._dense_layer(inputs)

# The model has not been called on any data yet, so its inputs are undefined.
my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
```

```py
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f96b1c92320>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
```
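
As a small cross-check on the data pipeline defined at the start of this tutorial, the sketch below reproduces the `938/938` step count shown in the training logs. It assumes the standard MNIST training split of 60,000 examples and a single replica (the logs list one device), and it relies on `.batch()` keeping the final partial batch, which is its default behavior. The helper names `global_batch_size` and `steps_per_epoch` are made up for illustration and are not part of TensorFlow.

```python
import math

MNIST_TRAIN_EXAMPLES = 60_000  # assumed size of the MNIST training split
BATCH_SIZE_PER_REPLICA = 64    # same value as in get_data()


def global_batch_size(per_replica: int, num_replicas: int) -> int:
    """Global batch size is the per-replica size times the replicas in sync."""
    return per_replica * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches per epoch when the final partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# Hypothetical replica counts, to show how the step count changes.
for replicas in (1, 2, 4):
    batch = global_batch_size(BATCH_SIZE_PER_REPLICA, replicas)
    steps = steps_per_epoch(MNIST_TRAIN_EXAMPLES, batch)
    print(f"{replicas} replica(s): global batch {batch}, {steps} steps per epoch")
```

With one replica this gives 938 steps per epoch, which matches the progress bars above. With more replicas the global batch grows, so each epoch takes fewer steps.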