
# Create an Estimator from a Keras model
> Original: [https://tensorflow.google.cn/tutorials/estimator/keras_model_to_estimator](https://tensorflow.google.cn/tutorials/estimator/keras_model_to_estimator)
## Overview
TensorFlow fully supports TensorFlow Estimators, and they can be created from new and existing [`tf.keras`](https://tensorflow.google.cn/api_docs/python/tf/keras) models. This tutorial contains a complete, minimal example of that process.
## Setup
```py
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
```
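### Create a simple Keras model
In Keras, you assemble *layers* to build *models*. A model is (usually) a graph of layers. The most common type of model is a stack of layers: the [`tf.keras.Sequential`](https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential) model.
To build a simple, fully connected network (that is, a multi-layer perceptron):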
```py
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3)
])
```
Compile the model and get a summary.
```py
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer='adam')
model.summary()
```
```py
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 16)                80
_________________________________________________________________
dropout (Dropout)            (None, 16)                0
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 51
=================================================================
Total params: 131
Trainable params: 131
Non-trainable params: 0
_________________________________________________________________
```
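The parameter counts in the summary can be checked by hand: a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and `Dropout` has no parameters. The helper `dense_params` below is a hypothetical name used only for this check.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


first = dense_params(4, 16)   # dense:   4 inputs -> 16 units
second = dense_params(16, 3)  # dense_1: 16 inputs -> 3 units
total = first + second

print(first, second, total)  # 80 51 131
```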
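### Create an input function
Use the [Datasets API](https://tensorflow.google.cn/guide/data) to scale to large datasets or multi-device training.
Estimators need control over when and how their input pipeline is built. To allow this, they require an "input function", or `input_fn`. The `Estimator` calls this function with no arguments. The `input_fn` must return a [`tf.data.Dataset`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset).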
```py
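def input_fn():
    split = tfds.Split.TRAIN
    dataset = tfds.load('iris', split=split, as_supervised=True)
    # The feature key must match the name of the model's input ('dense_input').
    dataset = dataset.map(lambda features, labels: ({'dense_input': features}, labels))
    # Batch the examples and repeat indefinitely so training can run for any number of steps.
    dataset = dataset.batch(32).repeat()
    return dataset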
```
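Because the `Estimator` calls `input_fn` with no arguments, any settings such as the split or the batch size have to be bound beforehand. One common way to do that is `functools.partial` (a `lambda` works as well). The sketch below shows only the binding pattern: `make_batches` is a hypothetical stand-in that returns plain Python lists, not a `tf.data.Dataset`.

```python
import functools


def make_batches(examples, batch_size):
    """Hypothetical stand-in for a parameterised input function."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


# Bind the arguments up front so the result can be called with none,
# which is the calling convention an Estimator expects from input_fn.
train_input_fn = functools.partial(make_batches, list(range(10)), batch_size=4)

print(train_input_fn())  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```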
Test out your `input_fn`:
```py
for features_batch, labels_batch in input_fn().take(1):
    print(features_batch)
    print(labels_batch)
```
```py
Downloading and preparing dataset iris/2.0.0 (download: 4.44 KiB, generated: Unknown size, total: 4.44 KiB) to /home/kbuilder/tensorflow_datasets/iris/2.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/iris/2.0.0.incompleteQ29ZWS/iris-train.tfrecord
Dataset iris downloaded and prepared to /home/kbuilder/tensorflow_datasets/iris/2.0.0. Subsequent calls will reuse this data.
{'dense_input': <tf.Tensor: shape=(32, 4), dtype=float32, numpy=
array([[5.1, 3.4, 1.5, 0.2],
       [7.7, 3. , 6.1, 2.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.8, 3.2, 5.9, 2.3],
       [5.2, 3.4, 1.4, 0.2],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [5.5, 2.4, 3.7, 1. ],
       [4.6, 3.4, 1.4, 0.3],
       [7.7, 2.8, 6.7, 2. ],
       [7. , 3.2, 4.7, 1.4],
       [4.6, 3.2, 1.4, 0.2],
       [6.5, 3. , 5.2, 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [5.4, 3.9, 1.3, 0.4],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [4.8, 3. , 1.4, 0.1],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [6.7, 3.3, 5.7, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [5.8, 4. , 1.2, 0.2],
       [6.3, 2.5, 5. , 1.9],
       [5. , 3. , 1.6, 0.2],
       [6.9, 3.1, 5.1, 2.3],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 2.7, 3.9, 1.4],
       [6.7, 3. , 5. , 1.7],
       [5.7, 2.6, 3.5, 1. ]], dtype=float32)>}
tf.Tensor([0 2 1 2 0 1 1 1 0 2 1 0 2 0 0 0 0 0 2 2 2 2 2 0 2 0 2 1 1 1 1 1], shape=(32,), dtype=int64)
```
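The `.batch(32).repeat()` call in `input_fn` is what lets training run for an arbitrary number of steps: without `repeat()`, the dataset would be exhausted after a single pass. The pure-Python sketch below imitates that behaviour with `itertools`. It assumes the Iris training split holds 150 examples (a figure this tutorial does not state), and `batched_repeat` is a hypothetical helper, not part of `tf.data`.

```python
import itertools


def batched_repeat(num_examples, batch_size):
    """Yield batch sizes forever, like dataset.batch(batch_size).repeat()."""
    sizes = [min(batch_size, num_examples - start)
             for start in range(0, num_examples, batch_size)]
    return itertools.cycle(sizes)


# Assumption: 150 examples in the Iris training split.
first_seven = list(itertools.islice(batched_repeat(150, 32), 7))
print(first_seven)  # [32, 32, 32, 32, 22, 32, 32]

batches_per_epoch = -(-150 // 32)  # ceiling division
print(batches_per_epoch)  # 5
```

Under that assumption, each pass over the data takes five steps, so a training run of 500 steps covers the data about 100 times.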
### Create an Estimator from the tf.keras model
A [`tf.keras.Model`](https://tensorflow.google.cn/api_docs/python/tf/keras/Model) can be trained with the [`tf.estimator`](https://tensorflow.google.cn/api_docs/python/tf/estimator) API by converting the model to a [`tf.estimator.Estimator`](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator) object with [`tf.keras.estimator.model_to_estimator`](https://tensorflow.google.cn/api_docs/python/tf/keras/estimator/model_to_estimator).
```py
import tempfile
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)
```
```py
INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp13998n2j', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
```
Train and evaluate the estimator.
```py
keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
```
```py
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp13998n2j/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp13998n2j/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp13998n2j/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.5731332, step = 0
INFO:tensorflow:global_step/sec: 444.326
INFO:tensorflow:loss = 0.79164267, step = 100 (0.227 sec)
INFO:tensorflow:global_step/sec: 515.459
INFO:tensorflow:loss = 0.5765847, step = 200 (0.193 sec)
INFO:tensorflow:global_step/sec: 518.855
INFO:tensorflow:loss = 0.48571444, step = 300 (0.193 sec)
INFO:tensorflow:global_step/sec: 527.318
INFO:tensorflow:loss = 0.3836534, step = 400 (0.190 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp13998n2j/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.46023262.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-22T19:57:20Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp13998n2j/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.16498s
INFO:tensorflow:Finished evaluation at 2020-09-22-19:57:20
INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.33660004
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp13998n2j/model.ckpt-500
Eval result: {'loss': 0.33660004, 'global_step': 500}
```
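The reported `loss` is the sparse categorical cross-entropy configured in `model.compile`, computed from the model's logits. As a rough aid to reading that number, the sketch below computes the loss for a single example in pure Python and then converts the evaluation loss back into a probability: `exp(-loss)` is roughly the geometric mean of the probability the model assigns to the correct class. The function name and the example logits are hypothetical.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


# Hypothetical logits for one flower whose true class is 0.
example_loss = sparse_ce_from_logits([2.0, 0.5, -1.0], label=0)
print(round(example_loss, 4))

# Geometric-mean probability of the true class implied by the eval loss above.
print(round(math.exp(-0.33660004), 3))  # 0.714
```

For comparison, a model that guessed uniformly among the three classes would score a loss of `log(3)`, about 1.10.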