# Multi-worker training with Estimator

> Original: [https://tensorflow.google.cn/tutorials/distribute/multi_worker_with_estimator](https://tensorflow.google.cn/tutorials/distribute/multi_worker_with_estimator)

**Note:** These documents were translated by our TensorFlow community. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest [official English documentation](https://tensorflow.google.cn/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the [tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the [docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn).

## Overview

This tutorial shows how to use [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy) for distributed multi-worker training. If your code uses [`tf.estimator`](https://tensorflow.google.cn/api_docs/python/tf/estimator) and you are interested in scaling beyond a single machine for high performance, this tutorial is for you.

Before getting started, please read the [`tf.distribute.Strategy` guide](https://tensorflow.google.cn/guide/distribute_strategy). The [multi-GPU training tutorial](https://tensorflow.google.cn/tutorials/distribute/keras) is also relevant, because this tutorial uses the same model.

## Setup

First, set up TensorFlow and the necessary imports.

```py
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os, json
```

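## Input function

This tutorial uses the MNIST dataset from [TensorFlow Datasets](https://tensorflow.google.cn/datasets). The code here is similar to the [multi-GPU training tutorial](https://tensorflow.google.cn/tutorials/distribute/keras), with one key difference: when using Estimator for multi-worker training, the dataset must be sharded by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct `1/num_workers` portion of the dataset.
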
```py
BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                                with_info=True,
                                as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    # Convert uint8 pixels in [0, 255] to float32 values in [0, 1].
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    # Give each input pipeline (one per worker) its own distinct slice of the data.
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
```
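
`Dataset.shard(num_shards, index)` keeps every `num_shards`-th element, starting at position `index`. The pure-Python sketch below mimics that selection rule on a list, so you can see which examples each worker would read in the two-worker setup used later; `shard_list` is a hypothetical helper for illustration, not a TensorFlow API.

```python
def shard_list(elements, num_shards, index):
    """Hypothetical helper mimicking tf.data.Dataset.shard on a Python list.

    Keeps the elements whose position i satisfies i % num_shards == index.
    """
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [x for i, x in enumerate(elements) if i % num_shards == index]


# Ten toy "examples" split across two workers.
examples = list(range(10))
worker_0 = shard_list(examples, num_shards=2, index=0)
worker_1 = shard_list(examples, num_shards=2, index=1)
print(worker_0)  # [0, 2, 4, 6, 8]
print(worker_1)  # [1, 3, 5, 7, 9]
```

The shards are disjoint and together cover the whole dataset. Because `input_fn` shards before calling `shuffle`, each worker then shuffles only its own shard.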

Another reasonable way to achieve convergence is to set a different random seed on each worker and shuffle the dataset with it.
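
In `tf.data` terms this would mean passing a different `seed` to `Dataset.shuffle` on each worker (one possible choice, assumed here, is to derive it from `input_context.input_pipeline_id`). The stdlib-only sketch below shows the idea with a hypothetical `base_seed + worker_index` scheme: unlike sharding, every worker still reads the whole dataset, only in a different order.

```python
import random


def shuffled_order(num_examples, worker_index, base_seed=42):
    """Return one worker's shuffled visiting order over the full dataset.

    Each worker derives its own seed from its index (an assumption of this
    sketch), so different workers traverse the data in different orders.
    """
    order = list(range(num_examples))
    random.Random(base_seed + worker_index).shuffle(order)
    return order


order_0 = shuffled_order(10, worker_index=0)
order_1 = shuffled_order(10, worker_index=1)
print(order_0)
print(order_1)
```

Seeding per worker keeps each worker's order reproducible across runs while still decorrelating the workers.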
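
## Multi-worker configuration

The key difference in this tutorial (compared with the [multi-GPU training tutorial](https://tensorflow.google.cn/tutorials/distribute/keras)) is the multi-worker setup. The standard way to specify the configuration of each worker in the cluster is the `TF_CONFIG` environment variable.

`TF_CONFIG` has two components: `cluster` and `task`. `cluster` provides information about the entire cluster, namely the workers and parameter servers in it. `task` provides information about the current task. In this example, the task `type` is `worker` and the task `index` is `0`.

For illustration purposes, this tutorial shows how to set `TF_CONFIG` for two workers on `localhost`. In practice, you would create multiple workers on external IP addresses and ports and set `TF_CONFIG` on each worker accordingly, that is, change the task `index` for each worker.

Warning: Do not run the following code in Colab. The TensorFlow runtime will try to create a gRPC server at the specified IP address and port, which will fail.
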
```py
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
```
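
Because only the task index changes from one worker to the next, it can help to generate every worker's `TF_CONFIG` from one shared address list. Below is a minimal sketch with a hypothetical `make_tf_config` helper; in a real deployment each worker process would set `os.environ['TF_CONFIG']` to its own string before creating the strategy.

```python
import json


def make_tf_config(worker_addresses, index):
    """Build the TF_CONFIG JSON string for the worker at position `index`.

    Hypothetical helper: every worker shares the same `cluster` section;
    only `task.index` differs between workers.
    """
    return json.dumps({
        'cluster': {'worker': list(worker_addresses)},
        'task': {'type': 'worker', 'index': index},
    })


workers = ["localhost:12345", "localhost:23456"]
configs = [make_tf_config(workers, i) for i in range(len(workers))]

# Here we only parse the strings back to inspect them.
for cfg in configs:
    parsed = json.loads(cfg)
    print(parsed['task'], len(parsed['cluster']['worker']))
```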

## Define the model

Define the layers, optimizer, and loss function for training. This tutorial defines the model with Keras layers, similar to the [multi-GPU training tutorial](https://tensorflow.google.cn/tutorials/distribute/keras).

```py
LEARNING_RATE = 1e-4

def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      # No softmax: the loss below uses from_logits=True, so the model
      # must output raw logits.
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    # EstimatorSpec requires `mode` and has no `labels` argument.
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  # Scale the summed per-example losses by 1 / BATCH_SIZE.
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))
```
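
As a sanity check on the architecture above, the plain-Python arithmetic below tracks the feature-map sizes and counts the trainable parameters, assuming the Keras defaults (`padding='valid'` and stride 1 for `Conv2D`, a 2x2 pool with stride 2 for `MaxPooling2D`). The total should match the count reported by `model.summary()` for this model.

```python
def conv2d_params(kernel, in_channels, filters):
    # kernel*kernel weights per (input channel, filter) pair, plus one bias per filter
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_features, units):
    # Weight matrix plus one bias per unit
    return in_features * units + units


# Conv2D(32, 3) with 'valid' padding and stride 1: 28 -> 26.
conv_out = 28 - 3 + 1
# MaxPooling2D() with a 2x2 pool and stride 2: 26 -> 13.
pool_out = conv_out // 2
flat = pool_out * pool_out * 32

params = [
    conv2d_params(3, 1, 32),  # Conv2D
    0,                        # MaxPooling2D has no weights
    0,                        # Flatten has no weights
    dense_params(flat, 64),   # Dense(64)
    dense_params(64, 10),     # Dense(10)
]
print(flat, params, sum(params))
```

Almost all of the parameters live in the first `Dense` layer, because it connects to every one of the flattened features.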

Note: Although the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
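
One common heuristic (an assumption here, not something this tutorial prescribes) is linear scaling: in synchronous training the global batch is the per-replica batch times the number of replicas, and the learning rate is scaled by the same factor. A stdlib sketch of that arithmetic, with hypothetical constant and function names:

```python
BASE_LEARNING_RATE = 1e-4    # LEARNING_RATE above, assumed tuned for one replica
PER_REPLICA_BATCH_SIZE = 64  # BATCH_SIZE in input_fn (per GPU)


def scaled_learning_rate(num_workers, gpus_per_worker=1,
                         base_lr=BASE_LEARNING_RATE,
                         per_replica_batch=PER_REPLICA_BATCH_SIZE):
    """Linear-scaling heuristic: scale the LR by the number of replicas.

    Returns (global_batch_size, scaled_learning_rate).
    """
    num_replicas = num_workers * gpus_per_worker
    global_batch = per_replica_batch * num_replicas
    return global_batch, base_lr * global_batch / per_replica_batch


global_batch, lr = scaled_learning_rate(num_workers=2)
print(global_batch, lr)
```

Large scaling factors are often paired with a learning-rate warm-up; treat the scaled value as a starting point for tuning rather than a guarantee.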

## MultiWorkerMirroredStrategy

To train the model, use an instance of [`tf.distribute.experimental.MultiWorkerMirroredStrategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy). `MultiWorkerMirroredStrategy` creates copies of all variables in the model's layers on each device across all workers. It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The [`tf.distribute.Strategy` guide](https://tensorflow.google.cn/guide/distribute_strategy) has more details about this strategy.
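
The pure-Python simulation below illustrates the synchronous, mirrored update pattern described above: every replica starts from the same copy of a variable, computes its own gradient, the gradients are aggregated with an all-reduce, and every replica applies the identical update, so the copies never diverge. It is a conceptual sketch only: the real strategy runs the all-reduce with `CollectiveOps` across devices and workers, and whether the aggregate is effectively a sum or a mean depends on how the loss is scaled; this sketch simply averages.

```python
def all_reduce_mean(per_replica_grads):
    """Average gradients element-wise across replicas."""
    num_replicas = len(per_replica_grads)
    return [sum(values) / num_replicas for values in zip(*per_replica_grads)]


def sync_step(replica_weights, per_replica_grads, lr):
    """One synchronous step: every replica applies the same reduced gradient."""
    mean_grad = all_reduce_mean(per_replica_grads)
    return [[w - lr * g for w, g in zip(weights, mean_grad)]
            for weights in replica_weights]


# Two replicas start from identical ("mirrored") copies of a 3-element variable.
weights = [[0.5, -1.0, 2.0], [0.5, -1.0, 2.0]]
# Each replica computes a different gradient on its own part of the data.
grads = [[0.2, 0.4, -0.6], [0.4, 0.0, -0.2]]

weights = sync_step(weights, grads, lr=0.1)
print(weights[0])
print(weights[0] == weights[1])  # True: the copies stay in sync
```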

```py
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
```

```
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO
```

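## Train and evaluate the model

Next, specify the distribution strategy in the `RunConfig` for the estimator, and train and evaluate the model by calling [`tf.estimator.train_and_evaluate`](https://tensorflow.google.cn/api_docs/python/tf/estimator/train_and_evaluate). This tutorial distributes only the training, by specifying the strategy via `train_distribute`. It is also possible to distribute the evaluation via `eval_distribute`.
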
```py
config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)
```

```
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x7f975c17f5f8>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f975c181c80> and will run it as-is.
Cause: could not parse the source code:

lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3033497, step = 0
INFO:tensorflow:global_step/sec: 195.373
INFO:tensorflow:loss = 2.3039753, step = 100 (0.514 sec)
INFO:tensorflow:global_step/sec: 214.711
INFO:tensorflow:loss = 2.3031363, step = 200 (0.465 sec)
INFO:tensorflow:global_step/sec: 217.488
INFO:tensorflow:loss = 2.3034592, step = 300 (0.460 sec)
INFO:tensorflow:global_step/sec: 218.917
INFO:tensorflow:loss = 2.3013198, step = 400 (0.457 sec)
INFO:tensorflow:global_step/sec: 219.726
INFO:tensorflow:loss = 2.3037362, step = 500 (0.455 sec)
INFO:tensorflow:global_step/sec: 219.401
INFO:tensorflow:loss = 2.3062348, step = 600 (0.455 sec)
INFO:tensorflow:global_step/sec: 220.068
INFO:tensorflow:loss = 2.300187, step = 700 (0.455 sec)
INFO:tensorflow:global_step/sec: 246.384
INFO:tensorflow:loss = 2.30475, step = 800 (0.405 sec)
INFO:tensorflow:global_step/sec: 587.13
INFO:tensorflow:loss = 2.3031988, step = 900 (0.170 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...
INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-22T19:53:28Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Inference Time : 0.98988s
INFO:tensorflow:Finished evaluation at 2020-09-22-19:53:29
INFO:tensorflow:Saving dict for global step 938: global_step = 938, loss = 2.3031592
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 938: /tmp/multiworker/model.ckpt-938
INFO:tensorflow:Loss for final step: 1.1519132.
({'loss': 2.3031592, 'global_step': 938}, [])
```
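
The numbers in the log above can be checked with a little arithmetic. Assuming the standard MNIST split of 60,000 training images, one pass over the unsharded training set at `BATCH_SIZE = 64` takes 938 steps, which is where training stopped: `TrainSpec` was given no `max_steps` and `input_fn` never calls `repeat()`, so training ends when the dataset is exhausted. Evaluation stops at `[100/100]` because `EvalSpec` defaults to `steps=100`. Most logged losses also sit close to ln(10), the cross-entropy of a uniform guess over 10 classes, which suggests that with this small learning rate one epoch barely moves the model away from chance.

```python
import math

MNIST_TRAIN_EXAMPLES = 60000
BATCH_SIZE = 64

# No max_steps and no repeat(): training covers one epoch of the
# (single-worker, unsharded) training set; the last batch is partial.
steps_per_epoch = math.ceil(MNIST_TRAIN_EXAMPLES / BATCH_SIZE)

# With two workers and sharding, each worker would read half the examples.
steps_per_epoch_per_worker = math.ceil((MNIST_TRAIN_EXAMPLES // 2) / BATCH_SIZE)

# Cross-entropy of a uniform prediction over the 10 MNIST classes.
uniform_loss = -math.log(1 / 10)

print(steps_per_epoch, steps_per_epoch_per_worker, round(uniform_loss, 4))
```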
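
## Optimize training performance

You now have a model and a multi-worker capable Estimator powered by [`tf.distribute.Strategy`](https://tensorflow.google.cn/api_docs/python/tf/distribute/Strategy). You can try the following techniques to optimize the performance of multi-worker training:

* *Increase the batch size:* The batch size specified here is per GPU. In general, the largest batch size that fits in GPU memory is advisable.
* *Cast variables:* Cast the variables to `tf.float` if possible. The official ResNet model includes an [example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how this can be done.
* *Use collective communication:* `MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/cross_device_ops.py).
    * `RING` implements ring-based collectives, using gRPC as the cross-host communication layer.
    * `NCCL` uses [NVIDIA's NCCL](https://developer.nvidia.com/nccl) to implement collectives.
    * `AUTO` defers the choice to the runtime.

The best choice of collective implementation depends not only on the number and kind of GPUs, but also on the network interconnect in the cluster. To override the automatic choice, specify the `communication` parameter of the `MultiWorkerMirroredStrategy` constructor, for example `communication=tf.distribute.experimental.CollectiveCommunication.NCCL`.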
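
To make the `RING` option more concrete, here is a pure-Python simulation of the textbook ring all-reduce: a reduce-scatter phase followed by an all-gather phase, in which each worker only ever sends to its neighbor on the ring, and each worker sends and receives about 2(N-1)/N of the data in total. It illustrates the general algorithm, not TensorFlow's actual implementation; `ring_all_reduce` is a hypothetical name.

```python
def ring_all_reduce(worker_chunks):
    """Textbook ring all-reduce (sum) over N workers.

    worker_chunks[i][c] is chunk c (a list of numbers) held by worker i;
    every worker holds N chunks, one per worker. Messages in a step are
    computed first and then delivered, to mimic simultaneous sends.
    """
    n = len(worker_chunks)
    data = [[list(chunk) for chunk in chunks] for chunks in worker_chunks]

    # Phase 1, reduce-scatter: after n-1 steps, worker i holds the complete
    # sum of chunk (i + 1) % n.
    for step in range(n - 1):
        messages = []
        for i in range(n):
            c = (i - step) % n
            messages.append(((i + 1) % n, c, list(data[i][c])))
        for dst, c, payload in messages:
            data[dst][c] = [a + b for a, b in zip(data[dst][c], payload)]

    # Phase 2, all-gather: pass the completed chunks around the ring.
    for step in range(n - 1):
        messages = []
        for i in range(n):
            c = (i + 1 - step) % n
            messages.append(((i + 1) % n, c, list(data[i][c])))
        for dst, c, payload in messages:
            data[dst][c] = payload
    return data


# Three workers, each holding a 6-element gradient split into 3 chunks of 2.
grads = [
    [[1, 2], [3, 4], [5, 6]],
    [[10, 20], [30, 40], [50, 60]],
    [[100, 200], [300, 400], [500, 600]],
]
result = ring_all_reduce(grads)
print(result[0])  # [[111, 222], [333, 444], [555, 666]]
```

Every worker ends up with the same element-wise sum, which is the guarantee the strategy relies on to keep mirrored variables in sync.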

## More code examples

1. An [end-to-end example](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy) that uses Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the [`tf.keras.estimator.model_to_estimator`](https://tensorflow.google.cn/api_docs/python/tf/keras/estimator/model_to_estimator) API.
2. The official [ResNet50](https://github.com/tensorflow/models/blob/master/official/resnet/imagenet_main.py) model, which can be trained using either `MirroredStrategy` or `MultiWorkerMirroredStrategy`.