# Distributed training with Keras
> Original: [https://tensorflow.google.cn/tutorials/distribute/keras](https://tensorflow.google.cn/tutorials/distribute/keras)
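**Note:** This document was translated by the TensorFlow community. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest [official English documentation](https://tensorflow.google.cn/?hl=en). If you have suggestions for improving this translation, please send a pull request to the [tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the [docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn).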
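## Overview

The [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy) API provides an abstraction for distributing your training across multiple processing units. Its goal is to let users enable distributed training with their existing models and training code, with minimal changes.

This tutorial uses [`tf.distribute.MirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy), which performs in-graph replication with synchronous training on multiple GPUs in a single machine. Essentially, it copies all of the model's variables to each processor. It then uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors and applies the combined value to every copy of the model.

`MirroredStrategy` is one of several distribution strategies available in TensorFlow. You can read about the others in the [distribution strategy guide](https://tensorflow.google.cn/guide/distribute_strategy).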
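The synchronous all-reduce update described above can be illustrated without TensorFlow. The sketch below is a toy model and not how `MirroredStrategy` is implemented: it assumes a scalar linear model `y_hat = w * x` with a mean-squared-error loss, represents each replica's mirrored variable as a plain Python float, and stands in for all-reduce with a simple average. The helper names (`local_gradient`, `all_reduce_mean`, `mirrored_step`) are hypothetical.

```python
from typing import List, Sequence, Tuple

# One replica's slice of the global batch: (x, y) pairs.
Shard = Sequence[Tuple[float, float]]


def local_gradient(w: float, shard: Shard) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w on one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Toy stand-in for all-reduce: every replica receives the same mean."""
    return sum(values) / len(values)


def mirrored_step(replica_ws: List[float], shards: List[Shard], lr: float) -> List[float]:
    """One synchronous step: local gradients -> all-reduce -> identical update."""
    grads = [local_gradient(w, shard) for w, shard in zip(replica_ws, shards)]
    reduced = all_reduce_mean(grads)
    # Every replica applies the same reduced gradient, so the copies stay in sync.
    return [w - lr * reduced for w in replica_ws]


# Two "replicas", each holding a mirrored copy of w and half of the data (y = 3x).
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
replica_ws = [0.0, 0.0]
for _ in range(50):
    replica_ws = mirrored_step(replica_ws, shards, lr=0.05)
print(replica_ws)
```

With equal-sized shards, the average of the per-replica gradients equals the full-batch gradient, so in this toy setting the distributed update matches what a single device would compute on the combined batch.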
### Keras API
This example uses the [`tf.keras`](https://tensorflow.google.cn/api_docs/python/tf/keras) API to build and train the model. For custom training loops, see the [tf.distribute.Strategy with training loops](/tutorials/distribute/training_loops) tutorial.
## Import dependencies
```py
# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
import os
```
```py
print(tf.__version__)
```
```
2.3.0
```
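## Download the dataset

Download the MNIST dataset and load it from [TensorFlow Datasets](https://tensorflow.google.cn/datasets). This returns a dataset in [`tf.data`](https://tensorflow.google.cn/api_docs/python/tf/data) format.

Setting `with_info` to `True` includes the metadata for the entire dataset, which is saved in `info`. Among other things, this metadata object includes the number of training and test examples.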
```py
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
```
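## Define the distribution strategy
Create a `MirroredStrategy` object. It handles the distribution and provides a context manager ([`tf.distribute.MirroredStrategy.scope`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy#scope)) inside which you build your model.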
```py
strategy = tf.distribute.MirroredStrategy()
```
```
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
```
```py
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```
```
Number of devices: 1
```
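## Set up the input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.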
```py
# You can also use info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
```
Pixel values, which range from 0 to 255, [have to be normalized to the 0-1 range](https://en.wikipedia.org/wiki/Feature_scaling). Define this scaling in a function.
```py
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label
```
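The `scale` function is the special case of min-max feature scaling (the technique described in the linked article) with a known minimum of 0 and maximum of 255:

$$
x' = \frac{x - x_{\min}}{x_{\max} - x_{\min}} = \frac{x - 0}{255 - 0} = \frac{x}{255}, \qquad x \in \{0, 1, \dots, 255\} \;\Rightarrow\; x' \in [0, 1].
$$

For example, a pixel value of 128 maps to 128/255, roughly 0.502.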
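Apply this function to the training and test data, shuffle the training data, and [batch it for training](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#batch). Notice that an in-memory cache of the training data is also kept to improve performance.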
```py
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
```
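To make the last two pipeline stages concrete, here is a pure-Python approximation of `shuffle` followed by `batch`. It is an illustration under stated assumptions, not the `tf.data` implementation: `buffered_shuffle` and `batch` are hypothetical helpers that mimic a fixed-size shuffle buffer and a batcher that keeps the final partial batch. One consequence carries over to the real pipeline: because `BUFFER_SIZE` (10,000) is smaller than the 60,000 training examples, the shuffle is only partial, since an element can be emitted only after it has entered the buffer. Placing `cache()` after `map(scale)` means scaling runs only on the first pass over the data, while `shuffle` after `cache()` still yields a new order every epoch.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximates a fixed-size shuffle buffer: fill it, emit a random element, refill."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Swap a random element to the end and pop it (O(1) removal).
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Groups consecutive elements; the final batch may be smaller."""
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


# Toy dataset of 1,000 "examples", a 100-element buffer, and batches of 64.
batches = list(batch(buffered_shuffle(range(1000), buffer_size=100), batch_size=64))
print(len(batches), len(batches[-1]), batches[0][:5])
```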
## Create the model
Create and compile the Keras model in the context of `strategy.scope`.
```py
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
```
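## Define the callbacks
The callbacks used here are:

* *TensorBoard*: This callback writes a log for TensorBoard, which lets you visualize the graphs.
* *Model Checkpoint*: This callback saves the model after every epoch.
* *Learning Rate Scheduler*: With this callback, you can schedule the learning rate to change from one epoch to the next.

For illustrative purposes, add a callback that prints the *learning rate* in the notebook.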
```py
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
```
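`ModelCheckpoint` fills the `{epoch}` placeholder with a 1-based epoch number when it saves, so a 12-epoch run produces the prefixes `ckpt_1` through `ckpt_12`. Formatting the prefix by hand, as in this sketch, shows the resulting paths. With `save_weights_only=True` and the TensorFlow checkpoint format, each prefix then becomes an `.index` file plus `.data-*` shard files, as the directory listing later in this tutorial shows.

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Keras passes a 0-based epoch index to callbacks; the saved name uses epoch + 1.
saved_prefixes = [checkpoint_prefix.format(epoch=epoch_index + 1)
                  for epoch_index in range(12)]
print(saved_prefixes[0], saved_prefixes[-1])
```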
```py
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
```
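`LearningRateScheduler` calls the schedule with a 0-based epoch index, while the `PrintLR` callback prints `epoch + 1`, so this schedule uses 1e-3 for epochs 1 to 3, 1e-4 for epochs 4 to 7 and 1e-5 for epochs 8 to 12. The optimizer stores the rate as a 32-bit float, which is why the training log prints `0.0010000000474974513` rather than `0.001`. The sketch below re-creates both effects in plain Python; `as_float32` is a hypothetical helper that round-trips a value through IEEE 754 single precision using `struct`.

```python
import struct


def decay(epoch):
    """Same schedule as above; `epoch` is the 0-based index Keras passes in."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


def as_float32(x: float) -> float:
    """Round-trips a Python float through 32-bit IEEE 754 storage."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Map the 1-based epoch numbers shown in the log to the rate used in that epoch.
schedule = {epoch_index + 1: decay(epoch_index) for epoch_index in range(12)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch}: {lr} (stored as {as_float32(lr)!r})")
```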
```py
# Callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # self.model is the model being trained by fit().
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, self.model.optimizer.lr.numpy()))
```
```py
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
```
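## Train and evaluate
Now train the model in the usual way: call `fit` on the model and pass in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.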
```py
model.fit(train_dataset, epochs=12, callbacks=callbacks)
```
```
Epoch 1/12
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
1/938 [..............................] - ETA: 0s - loss: 2.3194 - accuracy: 0.0938WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0046s vs `on_train_batch_end` time: 0.0296s). Check your callbacks.
932/938 [============================>.] - ETA: 0s - loss: 0.2055 - accuracy: 0.9422
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 4s 5ms/step - loss: 0.2049 - accuracy: 0.9424
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2/12
922/938 [============================>.] - ETA: 0s - loss: 0.0681 - accuracy: 0.9797
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0680 - accuracy: 0.9798
Epoch 3/12
930/938 [============================>.] - ETA: 0s - loss: 0.0484 - accuracy: 0.9855
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0484 - accuracy: 0.9855
Epoch 4/12
920/938 [============================>.] - ETA: 0s - loss: 0.0277 - accuracy: 0.9925
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0276 - accuracy: 0.9926
Epoch 5/12
931/938 [============================>.] - ETA: 0s - loss: 0.0248 - accuracy: 0.9935
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0247 - accuracy: 0.9936
Epoch 6/12
931/938 [============================>.] - ETA: 0s - loss: 0.0231 - accuracy: 0.9938
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0230 - accuracy: 0.9938
Epoch 7/12
936/938 [============================>.] - ETA: 0s - loss: 0.0217 - accuracy: 0.9941
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0216 - accuracy: 0.9941
Epoch 8/12
932/938 [============================>.] - ETA: 0s - loss: 0.0189 - accuracy: 0.9952
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0189 - accuracy: 0.9952
Epoch 9/12
932/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9953
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0187 - accuracy: 0.9953
Epoch 10/12
932/938 [============================>.] - ETA: 0s - loss: 0.0185 - accuracy: 0.9953
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0185 - accuracy: 0.9953
Epoch 11/12
934/938 [============================>.] - ETA: 0s - loss: 0.0183 - accuracy: 0.9953
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0184 - accuracy: 0.9953
Epoch 12/12
931/938 [============================>.] - ETA: 0s - loss: 0.0183 - accuracy: 0.9954
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0182 - accuracy: 0.9955
<tensorflow.python.keras.callbacks.History at 0x7fe470118978>
```
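The `fit` call is unchanged, but under `MirroredStrategy` each global batch is divided among the replicas; each replica runs the forward and backward pass on its share before the gradients are all-reduced. With one GPU, as in the log above, the single replica receives the whole batch of 64. The sketch below shows the split for a hypothetical two-replica machine. `split_global_batch` is an illustrative helper, not a TensorFlow API; it uses contiguous slices and assumes the global batch divides evenly.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_global_batch(global_batch: Sequence[T], num_replicas: int) -> List[List[T]]:
    """Divides one global batch into equal, contiguous per-replica slices."""
    if len(global_batch) % num_replicas != 0:
        raise ValueError("this sketch assumes the batch divides evenly")
    per_replica = len(global_batch) // num_replicas
    return [list(global_batch[i * per_replica:(i + 1) * per_replica])
            for i in range(num_replicas)]


global_batch = list(range(128))  # BATCH_SIZE_PER_REPLICA (64) * 2 replicas
shards = split_global_batch(global_batch, num_replicas=2)
print([len(s) for s in shards])
```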
The checkpoints are saved as shown below.
```py
# Inspect the checkpoint directory
ls {checkpoint_dir}
```
```
checkpoint ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001 ckpt_4.index
ckpt_1.index ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001 ckpt_5.index
ckpt_10.index ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001 ckpt_6.index
ckpt_11.index ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001 ckpt_7.index
ckpt_12.index ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001 ckpt_8.index
ckpt_2.index ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001 ckpt_9.index
ckpt_3.index
```
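`ls` sorts names lexicographically, which is why `ckpt_10` through `ckpt_12` appear right after `ckpt_1`. Choosing the newest checkpoint by string comparison would therefore pick `ckpt_9`. `tf.train.latest_checkpoint`, used in the next step, avoids this by reading the `checkpoint` state file saved alongside the weights. If you ever need to pick a checkpoint from file names yourself, sort on the numeric epoch instead, as in this sketch (the `epoch_of` helper is hypothetical and assumes the `ckpt_{epoch}` naming used here):

```python
import re

# The .index files from the listing above, in the order `ls` printed them.
index_files = ["ckpt_1.index", "ckpt_10.index", "ckpt_11.index", "ckpt_12.index",
               "ckpt_2.index", "ckpt_3.index", "ckpt_4.index", "ckpt_5.index",
               "ckpt_6.index", "ckpt_7.index", "ckpt_8.index", "ckpt_9.index"]


def epoch_of(filename: str) -> int:
    """Extracts N from 'ckpt_N.index' (assumes the ckpt_{epoch} naming scheme)."""
    match = re.fullmatch(r"ckpt_(\d+)\.index", filename)
    if match is None:
        raise ValueError(f"unexpected checkpoint file name: {filename}")
    return int(match.group(1))


lexicographic_latest = max(index_files)             # string comparison
numeric_latest = max(index_files, key=epoch_of)     # compares epoch numbers
print(lexicographic_latest, numeric_latest)
```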
To see how the model performs, load the latest checkpoint and call `evaluate` on the test dataset.
```py
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
```
```
157/157 [==============================] - 1s 6ms/step - loss: 0.0399 - accuracy: 0.9861
Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
```
To see the output, you can download the TensorBoard logs and view them from a terminal.
```sh
$ tensorboard --logdir=path/to/log-directory
```
```py
ls -sh ./logs
```
```
total 4.0K
4.0K train
```
## Export to SavedModel
Export the graph and the variables to the platform-agnostic SavedModel format. After the model is saved, you can load it with or without the scope.
```py
path = 'saved_model/'
```
```py
model.save(path, save_format='tf')
```
```
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/assets
```
Load the model without `strategy.scope`.
```py
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
```
```
157/157 [==============================] - 1s 3ms/step - loss: 0.0399 - accuracy: 0.9861
Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
```
Load the model with `strategy.scope`.
```py
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])
  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
```
```
157/157 [==============================] - 1s 5ms/step - loss: 0.0399 - accuracy: 0.9861
Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
```
## Examples and tutorials
Here are some examples of using distribution strategies with Keras `fit`/`compile`:
1. The [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer_main.py) example, trained using [`tf.distribute.MirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy).
2. The [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example, trained using [`tf.distribute.MirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/MirroredStrategy).

More examples are listed in the [distribution strategy guide](https://tensorflow.google.cn/guide/distribute_strategy#examples_and_tutorials).
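## Next steps
* Read the [distribution strategy guide](https://tensorflow.google.cn/guide/distribute_strategy).
* Read the [distributed training with custom training loops](/tutorials/distribute/training_loops) tutorial.

Note: [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy) is under active development, and more examples and tutorials will be added in the near future. Please give it a try. Feedback is welcome via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new).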