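# Save and load a model using a distribution strategy
> Original: [https://tensorflow.google.cn/tutorials/distribute/save_and_load](https://tensorflow.google.cn/tutorials/distribute/save_and_load)
## Overview
It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how to use the SavedModel APIs when using [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy). To learn about SavedModel and serialization in general, see the [SavedModel guide](https://tensorflow.google.cn/guide/saved_model) and the [Keras model serialization guide](https://tensorflow.google.cn/guide/keras/save_and_serialize). Let's start with a simple example:
Import dependencies: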
```py
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
```
Prepare the data and model using [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy):
```py
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
    datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
    mnist_train, mnist_test = datasets['train'], datasets['test']

    BUFFER_SIZE = 10000

    # The global batch size scales with the number of replicas in sync.
    BATCH_SIZE_PER_REPLICA = 64
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

    def scale(image, label):
        # Cast pixel values to float32 and scale them to [0, 1].
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label

    train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
    return train_dataset, eval_dataset

def get_model():
    # Create and compile the model inside the strategy scope.
    with mirrored_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])
        model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      optimizer=tf.keras.optimizers.Adam(),
                      metrics=['accuracy'])
        return model
```
```py
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
```
Train the model:
```py
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
```
```py
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 4s 5ms/step - loss: 0.1971 - accuracy: 0.9421
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0662 - accuracy: 0.9801
<tensorflow.python.keras.callbacks.History at 0x7f96501659e8>
```
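The `938/938` step count in the training log follows from the batch-size arithmetic in `get_data()`. The sketch below reproduces it in plain Python, assuming the standard MNIST training split of 60,000 examples and that the final, smaller batch is kept. `global_batch_size` and `steps_per_epoch` are hypothetical helper names for illustration, not TensorFlow APIs.

```python
import math

MNIST_TRAIN_EXAMPLES = 60_000  # assumption: size of the standard MNIST training split
BATCH_SIZE_PER_REPLICA = 64    # same value as in get_data()


def global_batch_size(per_replica: int, num_replicas: int) -> int:
    """Global batch size: per-replica batch size times the number of replicas in sync."""
    return per_replica * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches per epoch when the last, smaller batch is kept."""
    return math.ceil(num_examples / batch_size)


# Hypothetical replica counts, to show how the step count shrinks as replicas are added.
for replicas in (1, 2, 4):
    batch = global_batch_size(BATCH_SIZE_PER_REPLICA, replicas)
    steps = steps_per_epoch(MNIST_TRAIN_EXAMPLES, batch)
    print(f"{replicas} replica(s): global batch {batch}, {steps} steps per epoch")
```

With a single replica the global batch size is 64, which gives the 938 steps per epoch seen in the log.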
## Save and load the model
Now that you have a simple model to work with, let's take a look at the saving and loading APIs. There are two sets of APIs available:
* High-level Keras `model.save` and [`tf.keras.models.load_model`](https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model)
* Low-level [`tf.saved_model.save`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/save) and [`tf.saved_model.load`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load)
### Keras API
Here is an example of saving and loading a model with the Keras APIs:
```py
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path) # save() should be called out of strategy scope
```
```py
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
```
Restore the model without [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy):
```py
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
```
```py
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0480 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0334 - accuracy: 0.0989
<tensorflow.python.keras.callbacks.History at 0x7f96c54d0a58>
```
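After restoring the model, you can continue training it without calling `compile()` again, because it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format. For more information, see the [`saved_model` format guide](https://tensorflow.google.cn/guide/saved_model).
Now load the model and train it using [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy):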
```py
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
    restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
    restored_keras_model_ds.fit(train_dataset, epochs=2)
```
```py
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0481 - accuracy: 0.0989
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0329 - accuracy: 0.0990
```
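As you can see, loading works as expected with [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy). The strategy used here does not have to be the same strategy that was used before saving.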
### [`tf.saved_model`](https://tensorflow.google.cn/api_docs/python/tf/saved_model) API
Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:
```py
model = get_model() # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
```
```py
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
```
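You can load the model with [`tf.saved_model.load()`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load). However, because this API is lower level (and therefore covers a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions you can use for inference. For example: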
```py
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
```
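The loaded object may contain multiple functions, each associated with a key. `"serving_default"` is the default key for the inference function of a saved Keras model. To run inference with this function: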
```py
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
    print(inference_func(batch))
```
```py
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 0.17218862, 0.07492599, -0.0548683 , 0.03503785, -0.03743191,
-0.05301537, 0.01267872, -0.02870197, -0.33800656, 0.17991678],
[ 0.12937182, -0.21557797, -0.09474514, 0.39076763, -0.22147779,
-0.1787742 , 0.2154337 , 0.00788027, -0.14960325, 0.43123117],
[ 0.04755233, -0.20264567, -0.17308846, 0.19781005, -0.11123425,
-0.4295108 , 0.05442019, 0.01459119, -0.17129104, 0.04688327],
[ 0.09866484, 0.01627818, -0.08671301, 0.05742932, -0.20312837,
-0.38836166, -0.06952551, 0.05141062, -0.03084616, 0.05498504],
[ 0.00565811, -0.04239772, 0.04898138, 0.06162139, -0.16708252,
-0.12976539, -0.00474121, 0.05431085, -0.14715545, 0.07582194],
[ 0.17589626, 0.19629489, -0.2076093 , 0.02031662, -0.1619812 ,
-0.24300966, -0.0310282 , -0.00850905, -0.18514219, 0.23665032],
[-0.02653 , -0.17737214, -0.24494407, 0.20125583, -0.17153463,
-0.18641792, 0.11408111, 0.01489197, -0.099539 , 0.41159016],
[ 0.1903163 , 0.1697292 , -0.14116906, 0.1588785 , -0.04286646,
-0.19863203, -0.04836996, -0.00679918, -0.14634813, 0.14979276],
[ 0.12109621, 0.03313948, -0.1955429 , 0.23528968, -0.12369496,
-0.20725062, 0.06024174, 0.05078189, -0.158943 , 0.16846842],
[ 0.16227934, 0.06379895, -0.08847713, 0.08261362, -0.03925761,
-0.17770812, -0.043965 , 0.02072081, -0.07430968, 0.05749936],
[ 0.05508922, -0.14091367, -0.1887006 , 0.12903523, -0.13182093,
-0.11879301, 0.20175044, 0.11686974, -0.1616871 , 0.2226192 ],
[ 0.18285918, -0.01880376, -0.15778637, 0.04477023, -0.22364017,
-0.23864916, -0.06328501, 0.04380857, -0.04448643, 0.40406597],
[ 0.04721744, 0.06619421, -0.10837474, 0.1292499 , -0.17490903,
-0.17313394, -0.06603841, 0.15658481, -0.09657097, -0.04059617],
[-0.04412666, 0.02258963, 0.08539917, 0.2561011 , -0.18279126,
-0.2519745 , -0.00787598, 0.08598025, -0.21961546, 0.10189874],
[ 0.05089861, 0.06746367, -0.13205 , 0.09160744, -0.30171782,
-0.25160635, 0.08317091, 0.03015741, -0.10570806, 0.28686398],
[ 0.13625176, -0.109529 , 0.04985618, 0.08199271, -0.24280871,
-0.22908798, 0.17737128, 0.09937412, -0.31234092, 0.2290439 ],
[ 0.13812706, 0.10425253, 0.0128724 , 0.12191941, -0.09126505,
-0.13897963, -0.17568447, 0.16489705, -0.26533198, 0.06911667],
[ 0.16982701, 0.087276 , -0.17102191, 0.06745699, -0.06239565,
-0.17226742, -0.02450407, 0.10939141, -0.13510445, 0.04026298],
[-0.05762933, 0.03908077, 0.0729831 , 0.12001946, -0.12699135,
-0.37191632, -0.10294843, 0.1815257 , -0.10121268, 0.06880292],
[ 0.07649058, -0.03354908, -0.06362928, -0.00831218, -0.24217641,
-0.11137463, 0.01944396, 0.0310707 , 0.0093919 , 0.34353036],
[ 0.16107717, -0.04705916, -0.14095825, 0.05297582, -0.1485554 ,
-0.12321693, 0.07225874, 0.07695273, -0.17055047, 0.22460693],
[ 0.02565719, -0.05495968, -0.11961621, 0.03014402, -0.1645109 ,
-0.26333475, 0.07536604, 0.04426918, -0.12448484, 0.04142715],
[ 0.02295595, 0.01484419, -0.28111714, 0.05291839, -0.09908111,
-0.22002876, 0.00388122, 0.06801579, -0.03227042, 0.04201593],
[ 0.01293404, -0.15113808, -0.05814568, 0.29754263, -0.13849238,
-0.02268202, 0.16958144, 0.12881759, -0.13463333, 0.3364867 ],
[ 0.19805974, -0.01798259, -0.12835501, 0.26842418, -0.04154617,
-0.19442351, -0.08115683, 0.08586816, 0.00582654, 0.04328927],
[ 0.09159922, 0.12617984, -0.15028486, 0.23344447, -0.06932314,
-0.1483246 , -0.02017963, 0.03262286, -0.2800941 , 0.18364596],
[ 0.1528 , 0.13280275, -0.09938447, 0.03614349, -0.1096218 ,
-0.19335787, -0.04933339, -0.02397237, -0.13356304, -0.01165973],
[ 0.13618907, 0.14891617, -0.16118397, 0.10435603, -0.1831438 ,
-0.16405147, -0.14186187, 0.12581114, -0.15762964, 0.13493878],
[ 0.05534358, -0.0916103 , 0.0352111 , 0.0020496 , -0.19224274,
-0.17663556, 0.08702807, -0.08016825, -0.14833373, 0.10739949],
[ 0.02660379, -0.04472145, 0.01165188, 0.0219909 , -0.16059823,
-0.26817566, -0.09790543, 0.10905766, -0.01595427, 0.304615 ],
[ 0.08248052, -0.09962849, -0.02325149, 0.04280585, -0.20835052,
-0.2023199 , -0.0130603 , 0.07936736, 0.0494375 , 0.27143508],
[ 0.00310345, 0.04583906, -0.20415008, 0.1876276 , -0.06600557,
-0.19580218, -0.02222047, 0.07650423, -0.08899002, 0.10885157],
[ 0.0783096 , -0.01651647, -0.09479928, 0.07058451, -0.14990349,
-0.33366078, 0.0564964 , 0.01118498, -0.14589244, 0.22603557],
[ 0.04565446, 0.05590308, -0.02989801, -0.07578284, -0.09796432,
-0.20807403, -0.00954358, 0.02622838, -0.10276475, -0.05590656],
[ 0.07286316, 0.01376749, -0.18262148, 0.28560585, -0.18269306,
-0.06166455, 0.12229253, 0.11880912, -0.08595768, 0.17080015],
[ 0.12635507, -0.0836257 , 0.03501946, 0.30507207, -0.34584454,
-0.29186884, 0.26327768, 0.18378039, -0.09220086, 0.16707191],
[ 0.11742169, 0.02937749, -0.16469768, 0.31997636, -0.1280521 ,
-0.17700416, 0.05593231, 0.05017062, -0.31535 , 0.15465745],
[ 0.08975917, 0.01203279, 0.09783987, 0.06205256, -0.05648104,
-0.27429107, -0.12651348, 0.09195078, -0.2890005 , 0.08270936],
[ 0.09477694, 0.10097383, -0.05783979, 0.11597094, -0.05375554,
-0.04229444, -0.09689695, 0.08121311, -0.05716637, 0.09075539],
[-0.04117738, -0.06426363, -0.0629988 , 0.00692648, -0.30303234,
-0.28447956, -0.01935545, 0.159902 , -0.10399745, 0.17079492],
[-0.01080875, -0.04450692, -0.19694453, 0.15313052, -0.11790004,
-0.21164687, 0.16064486, 0.05443045, 0.04431828, 0.18498638],
[ 0.16398555, 0.21772492, -0.03592323, 0.15181649, -0.02455682,
-0.28267485, -0.12445807, 0.17047536, -0.19300474, -0.01467199],
[ 0.04904355, -0.0152067 , 0.09667489, -0.01841408, -0.08439851,
-0.2905228 , -0.0541675 , 0.07489735, -0.13492545, 0.1839124 ],
[ 0.2369909 , 0.08534706, -0.12017098, 0.04527019, -0.05781246,
-0.1196178 , -0.09442404, 0.01685349, -0.26979008, 0.17579612],
[ 0.04441281, -0.09139308, 0.00063404, 0.02085789, -0.17478338,
-0.1746104 , 0.21254838, 0.07575508, -0.19009903, 0.26038024],
[ 0.23913413, 0.13267268, -0.11951514, 0.13184579, -0.11442515,
-0.1563474 , -0.13503158, 0.1639925 , -0.11313978, 0.05294855],
[ 0.11768216, 0.12213368, -0.00641227, 0.1983034 , -0.10263431,
-0.10918278, -0.06888436, 0.26294842, -0.1041921 , 0.09731302],
[ 0.16183744, -0.14602011, -0.17195675, 0.1428874 , -0.26739907,
-0.3048862 , 0.06860068, 0.03065268, -0.13347332, 0.4117231 ],
[-0.02206257, 0.00734324, 0.003649 , 0.12295016, -0.22801307,
-0.23414296, -0.03367008, 0.11127277, -0.01726604, -0.0447302 ],
[ 0.10106434, 0.09055474, -0.12789255, 0.1377592 , -0.05564225,
-0.21510065, -0.09061419, -0.0219887 , -0.14411387, -0.03950592],
[ 0.12847602, -0.09453006, -0.04503661, 0.27597424, -0.17524761,
-0.05134012, 0.16526361, 0.08649909, -0.22461002, 0.45229536],
[ 0.04311011, 0.09949236, -0.04975891, 0.22421105, -0.12030718,
-0.09846736, -0.1408607 , 0.2384947 , -0.21582088, 0.01464934],
[-0.03788627, 0.04636163, 0.07747708, 0.0814044 , -0.12896554,
-0.31223392, -0.0578138 , 0.1859979 , -0.10911787, 0.15140374],
[ 0.08929176, -0.02551255, -0.06947158, 0.25500187, -0.18166143,
-0.1110489 , 0.0658811 , 0.23209906, -0.00346252, 0.27463445],
[ 0.12721871, -0.05336493, -0.01648436, 0.23337078, -0.22428553,
-0.17424905, 0.03487325, 0.28687072, 0.04055911, 0.30594033],
[ 0.18656036, -0.00513786, -0.16282284, 0.02530107, -0.17092519,
-0.24259233, 0.05227455, 0.19966123, -0.28181344, 0.14443643],
[ 0.02111852, -0.04639132, -0.01641255, 0.20416623, -0.11734181,
-0.08085347, 0.13685697, 0.10490854, -0.09023371, 0.32988763],
[ 0.06382357, 0.02803485, 0.03532831, 0.07898249, -0.10290041,
-0.2603921 , -0.03376516, 0.09166428, -0.14019875, 0.19503292],
[ 0.15105441, 0.0064583 , -0.1603775 , 0.16818096, -0.22179885,
-0.36698502, 0.12694073, -0.1294238 , -0.21702135, 0.34743598],
[ 0.11475793, -0.08016841, -0.19020993, 0.27748483, -0.13198294,
-0.22254312, 0.19926155, 0.19124901, -0.08933976, 0.25242418],
[ 0.09380357, -0.02989926, -0.01782445, 0.00312767, -0.02519768,
-0.43802148, -0.00290839, 0.04753356, -0.02965541, 0.10304467],
[ 0.20286047, -0.07675526, -0.03217752, 0.17366095, -0.13799758,
-0.27491322, 0.00279245, 0.14233288, -0.05951798, 0.36937428],
[ 0.01445094, -0.07265921, 0.10096341, 0.17594802, -0.17472097,
-0.2958681 , 0.0036519 , 0.03119059, -0.2027646 , -0.01793122],
[-0.02391969, -0.10441571, -0.00624696, 0.06563509, -0.14965585,
-0.3743796 , 0.0422266 , 0.04684277, 0.05023851, -0.07264638]],
dtype=float32)>}
```
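The tensor under the output key holds raw logits, because the final `Dense(10)` layer has no activation and the loss was built with `from_logits=True`. The sketch below turns one row of logits into class probabilities and a predicted digit in plain Python. `softmax` and `predict_class` are hypothetical helpers written for this illustration, and the sample row is copied from the first row of the output above. Because this model was freshly created and never trained, the prediction itself carries no meaning.

```python
import math

# First row of the logits printed above (one image, ten classes).
logits = [0.17218862, 0.07492599, -0.0548683, 0.03503785, -0.03743191,
          -0.05301537, 0.01267872, -0.02870197, -0.33800656, 0.17991678]


def softmax(values):
    """Convert logits to probabilities; subtracting the max keeps exp() stable."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def predict_class(values):
    """Index of the largest logit, which is also the most probable class."""
    return max(range(len(values)), key=lambda i: values[i])


probabilities = softmax(logits)
print("predicted class:", predict_class(logits))
print("probabilities sum to:", round(sum(probabilities), 6))
```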
You can also load the model and run inference in a distributed manner:
```py
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
    dist_predict_dataset = another_strategy.experimental_distribute_dataset(
        predict_dataset)
    # Calling the function in a distributed manner
    for batch in dist_predict_dataset:
        another_strategy.run(inference_func, args=(batch,))
```
```py
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
```
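Calling the restored function is just a forward pass (prediction) on the saved model. What if you want to continue training the loaded function, or embed it in a larger model? A common practice is to wrap the loaded object in a Keras layer. Fortunately, [TF Hub](https://tensorflow.google.cn/hub) provides [hub.KerasLayer](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, as shown here: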
```py
import tensorflow_hub as hub
def build_model(loaded):
    x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
    # Wrap what's loaded to a KerasLayer
    keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
    model = tf.keras.Model(x, keras_layer)
    return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    model = build_model(loaded)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    model.fit(train_dataset, epochs=2)
```
```py
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.2059 - accuracy: 0.9393
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0681 - accuracy: 0.9799
```
As you can see, [`hub.KerasLayer`](https://tensorflow.google.cn/hub/api_docs/python/hub/KerasLayer) wraps the result loaded back from [`tf.saved_model.load()`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load) into a Keras layer that can be used to build another model. This is very useful for transfer learning.
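### Which API should I use?
For saving, if you are working with a Keras model, the recommended choice is Keras's `model.save()` API. If what you are saving is not a Keras model, the lower-level API is your only option.
For loading, which API to use depends on what you want to get back from the loading API. If you cannot (or do not want to) get a Keras model, use [`tf.saved_model.load()`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load). Otherwise, use [`tf.keras.models.load_model()`](https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model). Note that you can get a Keras model back only if you saved a Keras model.
The APIs can be mixed and matched. You can save a Keras model with `model.save` and load it as a non-Keras model with the low-level API, [`tf.saved_model.load`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load).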
```py
model = get_model()
# Saving the model using Keras's save() API
model.save(keras_model_path)
another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
    loaded = tf.saved_model.load(keras_model_path)
```
```py
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
```
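### Caveats
A special case is a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shape (`Sequential([Dense(3), ...])`). Subclassed models also have no well-defined inputs after initialization. In this case, stick with the lower-level APIs for both saving and loading; otherwise you will get an error.
To check whether your model has well-defined inputs, check whether `model.inputs` is `None`. If it is not `None`, you are all set. Input shapes are defined automatically when the model is used in `.fit`, `.evaluate`, or `.predict`, or when the model is called (`model(inputs)`).
Here is an example: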
```py
class SubclassedModel(tf.keras.Model):

    output_name = 'output_layer'

    def __init__(self):
        super(SubclassedModel, self).__init__()
        self._dense_layer = tf.keras.layers.Dense(
            5, dtype=tf.dtypes.float32, name=self.output_name)

    def call(self, inputs):
        return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
```py
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f96b1c92320>, because it is not built.
Warning:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f96b1c92b70>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
```
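The guidance in the last two sections can be condensed into a small decision helper. This is only a sketch of the rules stated above, written in plain Python; the function names and boolean parameters are hypothetical and not part of TensorFlow.

```python
def choose_save_api(is_keras_model: bool, has_defined_inputs: bool) -> str:
    """Pick a save API following the rules in this tutorial."""
    if is_keras_model and has_defined_inputs:
        return "model.save"
    # Non-Keras objects, and Keras models without well-defined inputs,
    # fall back to the lower-level API.
    return "tf.saved_model.save"


def choose_load_api(saved_keras_model: bool,
                    had_defined_inputs: bool,
                    want_keras_model: bool) -> str:
    """Pick a load API; a Keras model can only come back if one was saved."""
    if not want_keras_model:
        return "tf.saved_model.load"
    if not (saved_keras_model and had_defined_inputs):
        raise ValueError(
            "A Keras model can be restored only from a saved Keras model "
            "with well-defined inputs.")
    return "tf.keras.models.load_model"


# The subclassed model above is a Keras model without defined inputs.
print(choose_save_api(is_keras_model=True, has_defined_inputs=False))
# A Keras model saved with model.save can still be loaded as a non-Keras object.
print(choose_load_api(saved_keras_model=True, had_defined_inputs=True,
                      want_keras_model=False))
```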