mirror of
https://github.com/Estom/notes.git
synced 2026-04-03 02:49:25 +08:00
450 lines
20 KiB
Markdown
# Distributed training with Keras

> Original: [https://tensorflow.google.cn/tutorials/distribute/keras](https://tensorflow.google.cn/tutorials/distribute/keras)

**Note:** These documents were translated by the TensorFlow community. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest [official English documentation](https://tensorflow.google.cn/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the [tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the [docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn).
## Overview

The [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy) API provides an abstraction for distributing training across multiple processing units. Its goal is to let users enable distributed training using existing models and training code, with minimal changes.

This tutorial uses [`tf.distribute.MirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy), which performs in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor, then uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors, and applies the combined value to all copies of the model.

`MirroredStrategy` is one of several distribution strategies available in TensorFlow. You can read about more strategies in the [distribution strategy guide](https://tensorflow.google.cn/guide/distribute_strategy).
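
The copy / all-reduce / apply cycle described in the overview can be sketched without any framework. This is a minimal illustration, assuming a toy one-parameter linear model `y = w * x` with mean-squared-error loss and two simulated replicas; the function names are hypothetical and this is not how `MirroredStrategy` is implemented. Modeling all-reduce as a plain mean over equal-sized shards is a simplification chosen so that the result equals the full-batch gradient.

```python
# Toy sketch of synchronous data-parallel training (NOT TensorFlow code).
# Assumption: one scalar weight w, model y = w * x, mean-squared-error loss.


def local_gradient(w, xs, ys):
    """Gradient of mean((w*x - y)^2) w.r.t. w, computed on one replica's shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Toy all-reduce: every replica receives the mean of all replicas' values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def mirrored_step(replica_weights, shards, lr):
    """Each replica computes a gradient on its own shard, the gradients are
    all-reduced, and every replica applies the identical update."""
    grads = [local_gradient(w, xs, ys)
             for w, (xs, ys) in zip(replica_weights, shards)]
    reduced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(replica_weights, reduced)]


# Hypothetical data following y = 3x, split evenly across 2 "replicas".
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 6.0, 9.0, 12.0]
shards = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]

weights = [0.0, 0.0]  # the variable is mirrored: same initial value on each replica
for _ in range(50):
    weights = mirrored_step(weights, shards, lr=0.05)

print(weights)  # both replicas hold the same value, close to 3.0
```

Because every replica starts from the same value and applies the same reduced gradient, the copies never drift apart; that invariant is what "mirrored" variables refer to.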
### Keras API

This example uses the [`tf.keras`](https://tensorflow.google.cn/api_docs/python/tf/keras) API to build and train the model. For custom training loops, see the [tf.distribute.Strategy with training loops](https://tensorflow.google.cn/tutorials/distribute/training_loops) tutorial.

## Import dependencies
```py
# Import TensorFlow and TensorFlow Datasets
import os

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()
```
```py
print(tf.__version__)
```

```
2.3.0
```
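## Download the dataset

Download the MNIST dataset and load it from [TensorFlow Datasets](https://tensorflow.google.cn/datasets). This returns a dataset in [`tf.data`](https://tensorflow.google.cn/api_docs/python/tf/data) format.

Setting `with_info` to `True` also returns metadata for the entire dataset, which is stored here in `info`. Among other things, this metadata object includes the number of training and test examples.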
```py
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
```
## Define the distribution strategy

Create a `MirroredStrategy` object. It handles the distribution and provides a context manager ([`tf.distribute.MirroredStrategy.scope`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy#scope)) inside which you build your model.
```py
strategy = tf.distribute.MirroredStrategy()
```

```
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
```
```py
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```

```
Number of devices: 1
```
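## Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.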
```py
# You can also use info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
```
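`BATCH_SIZE` above is the *global* batch size: each replica still processes `BATCH_SIZE_PER_REPLICA` examples per step, so adding replicas reduces the number of steps per epoch. The arithmetic is sketched below in plain Python, assuming MNIST's standard split of 60,000 training and 10,000 test examples. The linear learning-rate scaling function is one common heuristic for "tuning the learning rate accordingly"; it is an assumption and is not used anywhere in this tutorial.

```python
import math

# Assumed split sizes (MNIST's standard train/test split).
NUM_TRAIN_EXAMPLES = 60000
NUM_TEST_EXAMPLES = 10000
BATCH_SIZE_PER_REPLICA = 64


def global_batch_size(num_replicas):
    """Global batch = per-replica batch x number of replicas in sync."""
    return BATCH_SIZE_PER_REPLICA * num_replicas


def steps_per_epoch(num_examples, batch_size):
    """Batches per epoch when the final partial batch is kept
    (Dataset.batch keeps it by default, i.e. drop_remainder=False)."""
    return math.ceil(num_examples / batch_size)


def scaled_learning_rate(base_lr, num_replicas):
    """Linear scaling heuristic (an assumption, not part of this tutorial):
    grow the learning rate in proportion to the global batch size."""
    return base_lr * num_replicas


for replicas in (1, 2, 4):
    gbs = global_batch_size(replicas)
    print('replicas={} global_batch={} train_steps={} eval_steps={} lr={:g}'.format(
        replicas, gbs,
        steps_per_epoch(NUM_TRAIN_EXAMPLES, gbs),
        steps_per_epoch(NUM_TEST_EXAMPLES, gbs),
        scaled_learning_rate(1e-3, replicas)))
```

With a single replica this gives 938 training steps and 157 evaluation steps per epoch, which is consistent with the `938/938` and `157/157` progress bars in the outputs later in this tutorial.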
Pixel values, which range from 0 to 255, [have to be normalized to the 0-1 range](https://en.wikipedia.org/wiki/Feature_scaling). Define this scaling in a function.
```py
def scale(image, label):
    # Convert the uint8 pixels to float32 so the division is floating-point.
    image = tf.cast(image, tf.float32)
    # Map the [0, 255] range to [0, 1].
    image /= 255

    return image, label
```
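Dividing by 255 is min-max feature scaling, x' = (x - x_min) / (x_max - x_min), with the fixed pixel range x_min = 0 and x_max = 255, so the formula reduces to x / 255. Here is a plain-Python sketch of the same mapping on a few hypothetical pixel values; no TensorFlow is needed.

```python
def min_max_scale(x, x_min=0.0, x_max=255.0):
    """Min-max scaling; with the uint8 pixel range it is simply x / 255."""
    return (float(x) - x_min) / (x_max - x_min)


pixels = [0, 51, 128, 255]  # hypothetical uint8 pixel values
scaled = [min_max_scale(p) for p in pixels]
print(scaled)  # 0 maps to 0.0 and 255 maps to 1.0
```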
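Apply this function to the training and test data, shuffle the training data, and [batch it for training](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#batch). Note that an in-memory cache of the training data is also kept to improve performance.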
```py
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
```
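`shuffle(BUFFER_SIZE)` does not shuffle the whole dataset at once: it keeps a buffer of up to `BUFFER_SIZE` elements, emits a randomly chosen element from that buffer, and refills the freed slot from the input. The generators below are a simplified pure-Python approximation of that behavior and of `batch` with its default `drop_remainder=False`. They are for intuition only and are not the `tf.data` implementation.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Approximate buffer-based shuffling in the style of tf.data."""
    if buffer_size < 1:
        raise ValueError('buffer_size must be at least 1')
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch(iterable, batch_size):
    """Group items into lists of batch_size; the last batch may be smaller."""
    current = []
    for item in iterable:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


data = range(10)  # stand-in for (image, label) pairs
batches = list(batch(buffered_shuffle(data, buffer_size=4, seed=0), 3))
print(batches)  # four batches of sizes 3, 3, 3 and 1, in shuffled order
```

A buffer smaller than the dataset gives only a partial shuffle, which is why `BUFFER_SIZE` is set to a sizable 10,000 here. Because `cache()` comes before `shuffle()` in the pipeline above, the cache stores the scaled data in its original order and a fresh order is still drawn on every epoch.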
## Create the model

Create and compile the Keras model in the context of `strategy.scope`.
```py
# Variables created inside the scope are mirrored across all replicas.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
```
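Because every variable created under `strategy.scope()` is mirrored, each replica holds a full copy of the model's parameters. The plain-Python sketch below counts them by hand for the model above, assuming the Keras defaults of `'valid'` padding for `Conv2D` and a 2x2 pool (with matching stride) for `MaxPooling2D`.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Kernel weights (k x k x in x filters) plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(in_units, out_units):
    """Weight matrix (in x out) plus one bias per output unit."""
    return in_units * out_units + out_units


# Conv2D(32, 3) on a 28x28x1 input with 'valid' padding -> 26x26x32
conv_side = 28 - 3 + 1
p_conv = conv2d_params(3, 1, 32)

# MaxPooling2D() with a 2x2 pool -> 13x13x32; Flatten -> 5408 units
flat_units = (conv_side // 2) * (conv_side // 2) * 32

p_dense1 = dense_params(flat_units, 64)  # Dense(64)
p_dense2 = dense_params(64, 10)          # Dense(10)

total = p_conv + p_dense1 + p_dense2
print('parameters:', total)  # prints: parameters: 347146
# float32 uses 4 bytes per value; MirroredStrategy keeps one full copy per replica.
print('MB per replica copy: {:.2f}'.format(total * 4 / 1e6))
```

Almost all of the parameters sit in the first `Dense` layer. Adam also keeps two moment estimates per variable, so the optimizer state adds roughly twice that amount again on each replica.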
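## Define the callbacks

The callbacks used here are:

* *TensorBoard*: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
* *Model Checkpoint*: This callback saves the model (here only its weights, because `save_weights_only=True`) after every epoch.
* *Learning Rate Scheduler*: Using this callback, you can schedule the learning rate to change at the start of every epoch.

For illustrative purposes, add a print callback to display the *learning rate* in the notebook.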
```py
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
```
```py
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5
```
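The optimizer stores its learning rate as a float32 variable, so reading it back yields the float32 value nearest to `1e-3`, not `1e-3` itself. This explains why the training log prints values such as `0.0010000000474974513`. A small sketch using only the standard library `struct` module to round-trip the three values returned by `decay` through single precision:

```python
import struct


def as_float32(x):
    """Round-trip a Python float (float64) through IEEE-754 float32."""
    return struct.unpack('f', struct.pack('f', x))[0]


for lr in (1e-3, 1e-4, 1e-5):  # the three values returned by decay()
    stored = as_float32(lr)
    print('{:g} -> {!r} (error {:.1e})'.format(lr, stored, stored - lr))
```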
```py
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # self.model is the model this callback is attached to by fit().
        print('\nLearning rate for epoch {} is {}'.format(
            epoch + 1, self.model.optimizer.lr.numpy()))
```
```py
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
```
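The order in which `fit` invokes these callbacks explains the printed schedule. `LearningRateScheduler` sets the rate at the start of each epoch, passing a 0-based epoch index to `decay`, while `ModelCheckpoint` fills `{epoch}` in the file name with the 1-based epoch number, and `PrintLR` reports the rate at the end of the epoch. The framework-free sketch below mimics only those epoch-level hooks. All `Toy*` classes and `toy_fit` are hypothetical stand-ins, not Keras classes, and the batch-level hooks that real Keras also calls are omitted.

```python
def decay(epoch):
    # Same schedule as the tutorial's decay(); epoch is 0-based.
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    return 1e-5


class ToyCallback:
    """Hypothetical base class with no-op epoch hooks."""
    def on_epoch_begin(self, epoch, state):
        pass

    def on_epoch_end(self, epoch, state):
        pass


class ToyScheduler(ToyCallback):
    """Like LearningRateScheduler: sets the LR at the start of each epoch."""
    def __init__(self, schedule):
        self.schedule = schedule

    def on_epoch_begin(self, epoch, state):
        state['lr'] = self.schedule(epoch)


class ToyCheckpoint(ToyCallback):
    """Like ModelCheckpoint's file naming: '{epoch}' is filled with epoch + 1."""
    def __init__(self, filepath):
        self.filepath = filepath
        self.saved = []

    def on_epoch_end(self, epoch, state):
        self.saved.append(self.filepath.format(epoch=epoch + 1))


class ToyPrintLR(ToyCallback):
    """Like PrintLR: records the LR in effect, using 1-based epoch numbers."""
    def __init__(self):
        self.log = []

    def on_epoch_end(self, epoch, state):
        self.log.append((epoch + 1, state['lr']))


def toy_fit(epochs, callbacks):
    state = {'lr': None}  # stands in for the optimizer
    for epoch in range(epochs):  # 0-based, as Keras passes it to callbacks
        for cb in callbacks:
            cb.on_epoch_begin(epoch, state)
        # ... the training batches of this epoch would run here ...
        for cb in callbacks:
            cb.on_epoch_end(epoch, state)


ckpt = ToyCheckpoint('ckpt_{epoch}')
printer = ToyPrintLR()
toy_fit(12, [ckpt, ToyScheduler(decay), printer])
print(printer.log[:4])  # [(1, 0.001), (2, 0.001), (3, 0.001), (4, 0.0001)]
print(ckpt.saved[-1])   # ckpt_12
```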
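## Train and evaluate

Now train the model in the usual way: call `fit` on the model and pass in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.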
```py
model.fit(train_dataset, epochs=12, callbacks=callbacks)
```
```
Epoch 1/12
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
1/938 [..............................] - ETA: 0s - loss: 2.3194 - accuracy: 0.0938WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0046s vs `on_train_batch_end` time: 0.0296s). Check your callbacks.
932/938 [============================>.] - ETA: 0s - loss: 0.2055 - accuracy: 0.9422
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 4s 5ms/step - loss: 0.2049 - accuracy: 0.9424
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2/12
922/938 [============================>.] - ETA: 0s - loss: 0.0681 - accuracy: 0.9797
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0680 - accuracy: 0.9798
Epoch 3/12
930/938 [============================>.] - ETA: 0s - loss: 0.0484 - accuracy: 0.9855
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0484 - accuracy: 0.9855
Epoch 4/12
920/938 [============================>.] - ETA: 0s - loss: 0.0277 - accuracy: 0.9925
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0276 - accuracy: 0.9926
Epoch 5/12
931/938 [============================>.] - ETA: 0s - loss: 0.0248 - accuracy: 0.9935
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0247 - accuracy: 0.9936
Epoch 6/12
931/938 [============================>.] - ETA: 0s - loss: 0.0231 - accuracy: 0.9938
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0230 - accuracy: 0.9938
Epoch 7/12
936/938 [============================>.] - ETA: 0s - loss: 0.0217 - accuracy: 0.9941
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0216 - accuracy: 0.9941
Epoch 8/12
932/938 [============================>.] - ETA: 0s - loss: 0.0189 - accuracy: 0.9952
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0189 - accuracy: 0.9952
Epoch 9/12
932/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9953
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0187 - accuracy: 0.9953
Epoch 10/12
932/938 [============================>.] - ETA: 0s - loss: 0.0185 - accuracy: 0.9953
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0185 - accuracy: 0.9953
Epoch 11/12
934/938 [============================>.] - ETA: 0s - loss: 0.0183 - accuracy: 0.9953
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0184 - accuracy: 0.9953
Epoch 12/12
931/938 [============================>.] - ETA: 0s - loss: 0.0183 - accuracy: 0.9954
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0182 - accuracy: 0.9955

<tensorflow.python.keras.callbacks.History at 0x7fe470118978>
```
As you can see, the checkpoints are being saved:
```py
# Check the checkpoint directory (IPython/Jupyter shell escape)
!ls {checkpoint_dir}
```
```
checkpoint ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001 ckpt_4.index
ckpt_1.index ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001 ckpt_5.index
ckpt_10.index ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001 ckpt_6.index
ckpt_11.index ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001 ckpt_7.index
ckpt_12.index ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001 ckpt_8.index
ckpt_2.index ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001 ckpt_9.index
ckpt_3.index
```
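The `ckpt_{epoch}` prefix yields one checkpoint per epoch (`ckpt_1` to `ckpt_12`), each stored as an `.index` file plus a `.data` shard. `ls` sorts names lexicographically, so `ckpt_10` is listed before `ckpt_2`, and picking the "last" file by name would be wrong. `tf.train.latest_checkpoint` avoids this by reading the small `checkpoint` state file in the directory. The sketch below reproduces the ordering problem in plain Python; `newest_by_suffix` is a hypothetical fallback for illustration only, not a TensorFlow API.

```python
import re


def checkpoint_prefixes(num_epochs, template='ckpt_{epoch}'):
    """Prefixes produced by filling '{epoch}' with 1-based epoch numbers."""
    return [template.format(epoch=e) for e in range(1, num_epochs + 1)]


def newest_by_suffix(prefixes):
    """Hypothetical helper: pick the prefix with the largest trailing number.
    (tf.train.latest_checkpoint does not do this; it reads the 'checkpoint' file.)"""
    return max(prefixes, key=lambda p: int(re.search(r'(\d+)$', p).group(1)))


prefixes = checkpoint_prefixes(12)
files = sorted(f for p in prefixes
               for f in (p + '.index', p + '.data-00000-of-00001'))

print(files[:4])                   # ckpt_1 files first, then ckpt_10 files
print(sorted(prefixes)[-1])        # ckpt_9: the wrong "latest" by name
print(newest_by_suffix(prefixes))  # ckpt_12
```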
To see how the model performs, load the latest checkpoint and call `evaluate` on the test data (`eval_dataset`).
```py
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
```
```
157/157 [==============================] - 1s 6ms/step - loss: 0.0399 - accuracy: 0.9861
Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
```
To see the output, you can download the TensorBoard logs and view them from a terminal.
```shell
$ tensorboard --logdir=path/to/log-directory
```
```py
!ls -sh ./logs
```
```
total 4.0K
4.0K train
```
## Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format. After the model is saved, you can load it with or without the scope.
```py
path = 'saved_model/'
```
```py
model.save(path, save_format='tf')
```
```
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/assets
```
Load the model without `strategy.scope`.
```py
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
```
```
157/157 [==============================] - 1s 3ms/step - loss: 0.0399 - accuracy: 0.9861
Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
```
Load the model within `strategy.scope`.
```py
with strategy.scope():
    replicated_model = tf.keras.models.load_model(path)
    replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                             optimizer=tf.keras.optimizers.Adam(),
                             metrics=['accuracy'])

    eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
    print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
```
```
157/157 [==============================] - 1s 5ms/step - loss: 0.0399 - accuracy: 0.9861
Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
```
### Examples and tutorials

Here are some examples of using distribution strategies with Keras `fit`/`compile`:

1. A [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer_main.py) example trained using [`tf.distribute.MirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy).
2. An [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained using [`tf.distribute.MirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy).

More examples are listed in the [distribution strategy guide](https://tensorflow.google.cn/guide/distribute_strategy#examples_and_tutorials).
## Next steps

* Read the [distribution strategy guide](https://tensorflow.google.cn/guide/distribute_strategy).
* Read the [distributed training with custom training loops](https://tensorflow.google.cn/tutorials/distribute/training_loops) tutorial.

Note: [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy) is under active development, and more examples and tutorials will be added in the near future. Please give it a try. Feedback is welcome via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new).