mirror of
https://github.com/Estom/notes.git
synced 2026-04-05 11:57:37 +08:00
872 lines
22 KiB
Markdown
872 lines
22 KiB
Markdown
# 用 tf.data 加载图片
|
||
|
||
> 原文:[https://tensorflow.google.cn/tutorials/load_data/images](https://tensorflow.google.cn/tutorials/load_data/images)
|
||
|
||
**Note:** 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的 [官方英文文档](https://tensorflow.google.cn/?hl=en)。如果您有改进此翻译的建议, 请提交 pull request 到 [tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文,请加入 [docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。
|
||
|
||
本教程提供一个如何使用 [`tf.data`](https://tensorflow.google.cn/api_docs/python/tf/data) 加载图片的简单例子。
|
||
|
||
本例中使用的数据集分布在图片文件夹中,一个文件夹含有一类图片。
|
||
|
||
## 配置
|
||
|
||
```py
|
||
import tensorflow as tf
|
||
```
|
||
|
||
```py
|
||
AUTOTUNE = tf.data.experimental.AUTOTUNE
|
||
```
|
||
|
||
## 下载并检查数据集
|
||
|
||
### 检索图片
|
||
|
||
在你开始任何训练之前,你将需要一组图片来教会网络你想要训练的新类别。你已经创建了一个文件夹,存储了最初使用的拥有创作共用许可的花卉照片。
|
||
|
||
```py
|
||
import pathlib
|
||
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
|
||
fname='flower_photos', untar=True)
|
||
data_root = pathlib.Path(data_root_orig)
|
||
print(data_root)
|
||
```
|
||
|
||
```py
|
||
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
|
||
228818944/228813984 [==============================] - 2s 0us/step
|
||
/home/kbuilder/.keras/datasets/flower_photos
|
||
|
||
```
|
||
|
||
下载了 218 MB 之后,你现在应该有花卉照片副本:
|
||
|
||
```py
|
||
for item in data_root.iterdir():
|
||
print(item)
|
||
```
|
||
|
||
```py
|
||
/home/kbuilder/.keras/datasets/flower_photos/sunflowers
|
||
/home/kbuilder/.keras/datasets/flower_photos/daisy
|
||
/home/kbuilder/.keras/datasets/flower_photos/LICENSE.txt
|
||
/home/kbuilder/.keras/datasets/flower_photos/roses
|
||
/home/kbuilder/.keras/datasets/flower_photos/tulips
|
||
/home/kbuilder/.keras/datasets/flower_photos/dandelion
|
||
|
||
```
|
||
|
||
```py
|
||
import random
|
||
all_image_paths = list(data_root.glob('*/*'))
|
||
all_image_paths = [str(path) for path in all_image_paths]
|
||
random.shuffle(all_image_paths)
|
||
|
||
image_count = len(all_image_paths)
|
||
image_count
|
||
```
|
||
|
||
```py
|
||
3670
|
||
|
||
```
|
||
|
||
```py
|
||
all_image_paths[:10]
|
||
```
|
||
|
||
```py
|
||
['/home/kbuilder/.keras/datasets/flower_photos/daisy/4820415253_15bc3b6833_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/roses/14172324538_2147808483_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/15054866658_c1a6223403_m.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/daisy/422094774_28acc69a8b_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/roses/22982871191_ec61e36939_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/tulips/8673416166_620fc18e2f_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/tulips/16582481123_06e8e6b966_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/daisy/5434914569_e9b982fde0_n.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/184682652_c927a49226_m.jpg',
|
||
'/home/kbuilder/.keras/datasets/flower_photos/dandelion/3021333497_b927cd8596.jpg']
|
||
|
||
```
|
||
|
||
### 检查图片
|
||
|
||
现在让我们快速浏览几张图片,这样你知道你在处理什么:
|
||
|
||
```py
|
||
import os
|
||
attributions = (data_root/"LICENSE.txt").open(encoding='utf-8').readlines()[4:]
|
||
attributions = [line.split(' CC-BY') for line in attributions]
|
||
attributions = dict(attributions)
|
||
```
|
||
|
||
```py
|
||
import IPython.display as display
|
||
|
||
def caption_image(image_path):
|
||
image_rel = pathlib.Path(image_path).relative_to(data_root)
|
||
return "Image (CC BY 2.0) " + ' - '.join(attributions[str(image_rel)].split(' - ')[:-1])
|
||
```
|
||
|
||
```py
|
||
for n in range(3):
|
||
image_path = random.choice(all_image_paths)
|
||
display.display(display.Image(image_path))
|
||
print(caption_image(image_path))
|
||
print()
|
||
```
|
||
|
||

|
||
|
||
```py
|
||
Image (CC BY 2.0) by Pavlina Jane
|
||
|
||
```
|
||
|
||

|
||
|
||
```py
|
||
Image (CC BY 2.0) by Samantha Forsberg
|
||
|
||
```
|
||
|
||

|
||
|
||
```py
|
||
Image (CC BY 2.0) by Manu
|
||
|
||
```
|
||
|
||
### 确定每张图片的标签
|
||
|
||
列出可用的标签:
|
||
|
||
```py
|
||
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
|
||
label_names
|
||
```
|
||
|
||
```py
|
||
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
|
||
|
||
```
|
||
|
||
为每个标签分配索引:
|
||
|
||
```py
|
||
label_to_index = dict((name, index) for index, name in enumerate(label_names))
|
||
label_to_index
|
||
```
|
||
|
||
```py
|
||
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
|
||
|
||
```
|
||
|
||
创建一个列表,包含每个文件的标签索引:
|
||
|
||
```py
|
||
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
|
||
for path in all_image_paths]
|
||
|
||
print("First 10 labels indices: ", all_image_labels[:10])
|
||
```
|
||
|
||
```py
|
||
First 10 labels indices: [0, 2, 3, 0, 2, 4, 4, 0, 3, 1]
|
||
|
||
```
|
||
|
||
### 加载和格式化图片
|
||
|
||
TensorFlow 包含加载和处理图片时你需要的所有工具:
|
||
|
||
```py
|
||
img_path = all_image_paths[0]
|
||
img_path
|
||
```
|
||
|
||
```py
|
||
'/home/kbuilder/.keras/datasets/flower_photos/daisy/4820415253_15bc3b6833_n.jpg'
|
||
|
||
```
|
||
|
||
以下是原始数据:
|
||
|
||
```py
|
||
img_raw = tf.io.read_file(img_path)
|
||
print(repr(img_raw)[:100]+"...")
|
||
```
|
||
|
||
```py
|
||
<tf.Tensor: shape=(), dtype=string, numpy=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00...
|
||
|
||
```
|
||
|
||
将它解码为图像 tensor(张量):
|
||
|
||
```py
|
||
img_tensor = tf.image.decode_image(img_raw)
|
||
|
||
print(img_tensor.shape)
|
||
print(img_tensor.dtype)
|
||
```
|
||
|
||
```py
|
||
(224, 320, 3)
|
||
<dtype: 'uint8'>
|
||
|
||
```
|
||
|
||
根据你的模型调整其大小:
|
||
|
||
```py
|
||
img_final = tf.image.resize(img_tensor, [192, 192])
|
||
img_final = img_final/255.0
|
||
print(img_final.shape)
|
||
print(img_final.numpy().min())
|
||
print(img_final.numpy().max())
|
||
```
|
||
|
||
```py
|
||
(192, 192, 3)
|
||
0.0
|
||
1.0
|
||
|
||
```
|
||
|
||
将这些包装在一个简单的函数里,以备后用。
|
||
|
||
```py
|
||
def preprocess_image(image):
|
||
image = tf.image.decode_jpeg(image, channels=3)
|
||
image = tf.image.resize(image, [192, 192])
|
||
image /= 255.0 # normalize to [0,1] range
|
||
|
||
return image
|
||
```
|
||
|
||
```py
|
||
def load_and_preprocess_image(path):
|
||
image = tf.io.read_file(path)
|
||
return preprocess_image(image)
|
||
```
|
||
|
||
```py
|
||
import matplotlib.pyplot as plt
|
||
|
||
image_path = all_image_paths[0]
|
||
label = all_image_labels[0]
|
||
|
||
plt.imshow(load_and_preprocess_image(img_path))
|
||
plt.grid(False)
|
||
plt.xlabel(caption_image(img_path))
|
||
plt.title(label_names[label].title())
|
||
print()
|
||
```
|
||
|
||

|
||
|
||
## 构建一个 [`tf.data.Dataset`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset)
|
||
|
||
### 一个图片数据集
|
||
|
||
构建 [`tf.data.Dataset`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset) 最简单的方法就是使用 `from_tensor_slices` 方法。
|
||
|
||
将字符串数组切片,得到一个字符串数据集:
|
||
|
||
```py
|
||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||
```
|
||
|
||
`shapes(维数)` 和 `types(类型)` 描述数据集里每个数据项的内容。在这里是一组标量二进制字符串。
|
||
|
||
```py
|
||
print(path_ds)
|
||
```
|
||
|
||
```py
|
||
<TensorSliceDataset shapes: (), types: tf.string>
|
||
|
||
```
|
||
|
||
现在创建一个新的数据集,通过在路径数据集上映射 `preprocess_image` 来动态加载和格式化图片。
|
||
|
||
```py
|
||
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
|
||
```
|
||
|
||
```py
|
||
import matplotlib.pyplot as plt
|
||
|
||
plt.figure(figsize=(8,8))
|
||
for n, image in enumerate(image_ds.take(4)):
|
||
plt.subplot(2,2,n+1)
|
||
plt.imshow(image)
|
||
plt.grid(False)
|
||
plt.xticks([])
|
||
plt.yticks([])
|
||
plt.xlabel(caption_image(all_image_paths[n]))
|
||
plt.show()
|
||
```
|
||
|
||

|
||
|
||

|
||
|
||

|
||
|
||

|
||
|
||
### 一个`(图片, 标签)`对数据集
|
||
|
||
使用同样的 `from_tensor_slices` 方法你可以创建一个标签数据集:
|
||
|
||
```py
|
||
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
|
||
```
|
||
|
||
```py
|
||
for label in label_ds.take(10):
|
||
print(label_names[label.numpy()])
|
||
```
|
||
|
||
```py
|
||
daisy
|
||
roses
|
||
sunflowers
|
||
daisy
|
||
roses
|
||
tulips
|
||
tulips
|
||
daisy
|
||
sunflowers
|
||
dandelion
|
||
|
||
```
|
||
|
||
由于这些数据集顺序相同,你可以将他们打包在一起得到一个`(图片, 标签)`对数据集:
|
||
|
||
```py
|
||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||
```
|
||
|
||
这个新数据集的 `shapes(维数)` 和 `types(类型)` 也是维数和类型的元组,用来描述每个字段:
|
||
|
||
```py
|
||
print(image_label_ds)
|
||
```
|
||
|
||
```py
|
||
<ZipDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int64)>
|
||
|
||
```
|
||
|
||
注意:当你拥有形似 `all_image_labels` 和 `all_image_paths` 的数组,`tf.data.dataset.Dataset.zip` 的替代方法是将这对数组切片。
|
||
|
||
```py
|
||
ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
|
||
|
||
# 元组被解压缩到映射函数的位置参数中
|
||
def load_and_preprocess_from_path_label(path, label):
|
||
return load_and_preprocess_image(path), label
|
||
|
||
image_label_ds = ds.map(load_and_preprocess_from_path_label)
|
||
image_label_ds
|
||
```
|
||
|
||
```py
|
||
<MapDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>
|
||
|
||
```
|
||
|
||
### 训练的基本方法
|
||
|
||
要使用此数据集训练模型,你将会想要数据:
|
||
|
||
* 被充分打乱。
|
||
* 被分割为 batch。
|
||
* 永远重复。
|
||
* 尽快提供 batch。
|
||
|
||
使用 [`tf.data`](https://tensorflow.google.cn/api_docs/python/tf/data) api 可以轻松添加这些功能。
|
||
|
||
```py
|
||
BATCH_SIZE = 32
|
||
|
||
# 设置一个和数据集大小一致的 shuffle buffer size(随机缓冲区大小)以保证数据
|
||
# 被充分打乱。
|
||
ds = image_label_ds.shuffle(buffer_size=image_count)
|
||
ds = ds.repeat()
|
||
ds = ds.batch(BATCH_SIZE)
|
||
# 当模型在训练的时候,`prefetch` 使数据集在后台取得 batch。
|
||
ds = ds.prefetch(buffer_size=AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
|
||
|
||
```
|
||
|
||
这里有一些注意事项:
|
||
|
||
1. 顺序很重要。
|
||
|
||
* 在 `.repeat` 之后 `.shuffle`,会在 epoch 之间打乱数据(当有些数据出现两次的时候,其他数据还没有出现过)。
|
||
|
||
* 在 `.batch` 之后 `.shuffle`,会打乱 batch 的顺序,但是不会在 batch 之间打乱数据。
|
||
|
||
2. 你在完全打乱中使用和数据集大小一样的 `buffer_size(缓冲区大小)`。较大的缓冲区大小提供更好的随机化,但使用更多的内存,直到超过数据集大小。
|
||
|
||
3. 在从随机缓冲区中拉取任何元素前,要先填满它。所以当你的 `Dataset(数据集)`启动的时候一个大的 `buffer_size(缓冲区大小)`可能会引起延迟。
|
||
|
||
4. 在随机缓冲区完全为空之前,被打乱的数据集不会报告数据集的结尾。`Dataset(数据集)`由 `.repeat` 重新启动,导致需要再次等待随机缓冲区被填满。
|
||
|
||
最后一点可以通过使用 [`tf.data.Dataset.apply`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#apply) 方法和融合过的 [`tf.data.experimental.shuffle_and_repeat`](https://tensorflow.google.cn/api_docs/python/tf/data/experimental/shuffle_and_repeat) 函数来解决:
|
||
|
||
```py
|
||
ds = image_label_ds.apply(
|
||
tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
|
||
ds = ds.batch(BATCH_SIZE)
|
||
ds = ds.prefetch(buffer_size=AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
WARNING:tensorflow:From <ipython-input-1-4dc713bd4d84>:2: shuffle_and_repeat (from tensorflow.python.data.experimental.ops.shuffle_ops) is deprecated and will be removed in a future version.
|
||
Instructions for updating:
|
||
Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.
|
||
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
|
||
|
||
```
|
||
|
||
### 传递数据集至模型
|
||
|
||
从 [`tf.keras.applications`](https://tensorflow.google.cn/api_docs/python/tf/keras/applications) 取得 MobileNet v2 副本。
|
||
|
||
该模型副本会被用于一个简单的迁移学习例子。
|
||
|
||
设置 MobileNet 的权重为不可训练:
|
||
|
||
```py
|
||
mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3), include_top=False)
|
||
mobile_net.trainable=False
|
||
```
|
||
|
||
```py
|
||
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_192_no_top.h5
|
||
9412608/9406464 [==============================] - 0s 0us/step
|
||
|
||
```
|
||
|
||
该模型期望它的输出被标准化至 `[-1,1]` 范围内:
|
||
|
||
```py
|
||
help(keras_applications.mobilenet_v2.preprocess_input)
|
||
```
|
||
|
||
```py
|
||
……
|
||
该函数使用“Inception”预处理,将
|
||
RGB 值从 [0, 255] 转化为 [-1, 1]
|
||
……
|
||
|
||
```
|
||
|
||
在你将输出传递给 MobilNet 模型之前,你需要将其范围从 `[0,1]` 转化为 `[-1,1]`:
|
||
|
||
```py
|
||
def change_range(image,label):
|
||
return 2*image-1, label
|
||
|
||
keras_ds = ds.map(change_range)
|
||
```
|
||
|
||
MobileNet 为每张图片的特征返回一个 `6x6` 的空间网格。
|
||
|
||
传递一个 batch 的图片给它,查看结果:
|
||
|
||
```py
|
||
# 数据集可能需要几秒来启动,因为要填满其随机缓冲区。
|
||
image_batch, label_batch = next(iter(keras_ds))
|
||
```
|
||
|
||
```py
|
||
feature_map_batch = mobile_net(image_batch)
|
||
print(feature_map_batch.shape)
|
||
```
|
||
|
||
```py
|
||
(32, 6, 6, 1280)
|
||
|
||
```
|
||
|
||
构建一个包装了 MobileNet 的模型并在 [`tf.keras.layers.Dense`](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Dense) 输出层之前使用 [`tf.keras.layers.GlobalAveragePooling2D`](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GlobalAveragePooling2D) 来平均那些空间向量:
|
||
|
||
```py
|
||
model = tf.keras.Sequential([
|
||
mobile_net,
|
||
tf.keras.layers.GlobalAveragePooling2D(),
|
||
tf.keras.layers.Dense(len(label_names), activation = 'softmax')])
|
||
```
|
||
|
||
现在它产出符合预期 shape(维数)的输出:
|
||
|
||
```py
|
||
logit_batch = model(image_batch).numpy()
|
||
|
||
print("min logit:", logit_batch.min())
|
||
print("max logit:", logit_batch.max())
|
||
print()
|
||
|
||
print("Shape:", logit_batch.shape)
|
||
```
|
||
|
||
```py
|
||
min logit: 0.0039403443
|
||
max logit: 0.82328725
|
||
|
||
Shape: (32, 5)
|
||
|
||
```
|
||
|
||
编译模型以描述训练过程:
|
||
|
||
```py
|
||
model.compile(optimizer=tf.keras.optimizers.Adam(),
|
||
loss='sparse_categorical_crossentropy',
|
||
metrics=["accuracy"])
|
||
```
|
||
|
||
此处有两个可训练的变量 —— Dense 层中的 `weights(权重)` 和 `bias(偏差)`:
|
||
|
||
```py
|
||
len(model.trainable_variables)
|
||
```
|
||
|
||
```py
|
||
2
|
||
|
||
```
|
||
|
||
```py
|
||
model.summary()
|
||
```
|
||
|
||
```py
|
||
Model: "sequential"
|
||
_________________________________________________________________
|
||
Layer (type) Output Shape Param #
|
||
=================================================================
|
||
mobilenetv2_1.00_192 (Functi (None, 6, 6, 1280) 2257984
|
||
_________________________________________________________________
|
||
global_average_pooling2d (Gl (None, 1280) 0
|
||
_________________________________________________________________
|
||
dense (Dense) (None, 5) 6405
|
||
=================================================================
|
||
Total params: 2,264,389
|
||
Trainable params: 6,405
|
||
Non-trainable params: 2,257,984
|
||
_________________________________________________________________
|
||
|
||
```
|
||
|
||
你已经准备好来训练模型了。
|
||
|
||
注意,出于演示目的每一个 epoch 中你将只运行 3 step,但一般来说在传递给 `model.fit()` 之前你会指定 step 的真实数量,如下所示:
|
||
|
||
```py
|
||
steps_per_epoch=tf.math.ceil(len(all_image_paths)/BATCH_SIZE).numpy()
|
||
steps_per_epoch
|
||
```
|
||
|
||
```py
|
||
115.0
|
||
|
||
```
|
||
|
||
```py
|
||
model.fit(ds, epochs=1, steps_per_epoch=3)
|
||
```
|
||
|
||
```py
|
||
3/3 [==============================] - 0s 31ms/step - loss: 1.8837 - accuracy: 0.2812
|
||
|
||
<tensorflow.python.keras.callbacks.History at 0x7f43ec118eb8>
|
||
|
||
```
|
||
|
||
## 性能
|
||
|
||
注意:这部分只是展示一些可能帮助提升性能的简单技巧。深入指南,请看:[输入 pipeline(管道)的性能](https://tensorflow.google.cn/guide/performance/datasets)。
|
||
|
||
上面使用的简单 pipeline(管道)在每个 epoch 中单独读取每个文件。在本地使用 CPU 训练时这个方法是可行的,但是可能不足以进行 GPU 训练并且完全不适合任何形式的分布式训练。
|
||
|
||
要研究这点,首先构建一个简单的函数来检查数据集的性能:
|
||
|
||
```py
|
||
import time
|
||
default_timeit_steps = 2*steps_per_epoch+1
|
||
|
||
def timeit(ds, steps=default_timeit_steps):
|
||
overall_start = time.time()
|
||
# 在开始计时之前
|
||
# 取得单个 batch 来填充 pipeline(管道)(填充随机缓冲区)
|
||
it = iter(ds.take(steps+1))
|
||
next(it)
|
||
|
||
start = time.time()
|
||
for i,(images,labels) in enumerate(it):
|
||
if i%10 == 0:
|
||
print('.',end='')
|
||
print()
|
||
end = time.time()
|
||
|
||
duration = end-start
|
||
print("{} batches: {} s".format(steps, duration))
|
||
print("{:0.5f} Images/s".format(BATCH_SIZE*steps/duration))
|
||
print("Total time: {}s".format(end-overall_start))
|
||
```
|
||
|
||
当前数据集的性能是:
|
||
|
||
```py
|
||
ds = image_label_ds.apply(
|
||
tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
|
||
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
|
||
|
||
```
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 14.869637966156006 s
|
||
497.12037 Images/s
|
||
Total time: 21.789817333221436s
|
||
|
||
```
|
||
|
||
### 缓存
|
||
|
||
使用 [`tf.data.Dataset.cache`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#cache) 在 epoch 之间轻松缓存计算结果。这是非常高效的,特别是当内存能容纳全部数据时。
|
||
|
||
在被预处理之后(解码和调整大小),图片在此被缓存了:
|
||
|
||
```py
|
||
ds = image_label_ds.cache()
|
||
ds = ds.apply(
|
||
tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
|
||
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
|
||
|
||
```
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 0.5994970798492432 s
|
||
12330.33529 Images/s
|
||
Total time: 7.475242614746094s
|
||
|
||
```
|
||
|
||
使用内存缓存的一个缺点是必须在每次运行时重建缓存,这使得每次启动数据集时有相同的启动延迟:
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 0.6120779514312744 s
|
||
12076.89312 Images/s
|
||
Total time: 0.6253445148468018s
|
||
|
||
```
|
||
|
||
如果内存不够容纳数据,使用一个缓存文件:
|
||
|
||
```py
|
||
ds = image_label_ds.cache(filename='./cache.tf-data')
|
||
ds = ds.apply(
|
||
tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
|
||
ds = ds.batch(BATCH_SIZE).prefetch(1)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
|
||
|
||
```
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 3.0341720581054688 s
|
||
2436.24945 Images/s
|
||
Total time: 12.044088363647461s
|
||
|
||
```
|
||
|
||
这个缓存文件也有可快速重启数据集而无需重建缓存的优点。注意第二次快了多少:
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 2.358055353164673 s
|
||
3134.78646 Images/s
|
||
Total time: 3.105525493621826s
|
||
|
||
```
|
||
|
||
### TFRecord 文件
|
||
|
||
#### 原始图片数据
|
||
|
||
TFRecord 文件是一种用来存储一串二进制 blob 的简单格式。通过将多个示例打包进同一个文件内,TensorFlow 能够一次性读取多个示例,当使用一个远程存储服务,如 GCS 时,这对性能来说尤其重要。
|
||
|
||
首先,从原始图片数据中构建出一个 TFRecord 文件:
|
||
|
||
```py
|
||
image_ds = tf.data.Dataset.from_tensor_slices(all_image_paths).map(tf.io.read_file)
|
||
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
|
||
tfrec.write(image_ds)
|
||
```
|
||
|
||
接着,构建一个从 TFRecord 文件读取的数据集,并使用你之前定义的 `preprocess_image` 函数对图像进行解码/重新格式化:
|
||
|
||
```py
|
||
image_ds = tf.data.TFRecordDataset('images.tfrec').map(preprocess_image)
|
||
```
|
||
|
||
压缩该数据集和你之前定义的标签数据集以得到期望的 `(图片,标签)` 对:
|
||
|
||
```py
|
||
ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||
ds = ds.apply(
|
||
tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
|
||
ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int64)>
|
||
|
||
```
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 14.661343574523926 s
|
||
504.18299 Images/s
|
||
Total time: 21.57948637008667s
|
||
|
||
```
|
||
|
||
这比 `缓存` 版本慢,因为你还没有缓存预处理。
|
||
|
||
#### 序列化的 Tensor(张量)
|
||
|
||
要为 TFRecord 文件省去一些预处理过程,首先像之前一样制作一个处理过的图片数据集:
|
||
|
||
```py
|
||
paths_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||
image_ds = paths_ds.map(load_and_preprocess_image)
|
||
image_ds
|
||
```
|
||
|
||
```py
|
||
<MapDataset shapes: (192, 192, 3), types: tf.float32>
|
||
|
||
```
|
||
|
||
现在你有一个 tensor(张量)数据集,而不是一个 `.jpeg` 字符串数据集。
|
||
|
||
要将此序列化至一个 TFRecord 文件你首先将该 tensor(张量)数据集转化为一个字符串数据集:
|
||
|
||
```py
|
||
ds = image_ds.map(tf.io.serialize_tensor)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<MapDataset shapes: (), types: tf.string>
|
||
|
||
```
|
||
|
||
```py
|
||
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
|
||
tfrec.write(ds)
|
||
```
|
||
|
||
有了被缓存的预处理,就能从 TFrecord 文件高效地加载数据——只需记得在使用它之前反序列化:
|
||
|
||
```py
|
||
ds = tf.data.TFRecordDataset('images.tfrec')
|
||
|
||
def parse(x):
|
||
result = tf.io.parse_tensor(x, out_type=tf.float32)
|
||
result = tf.reshape(result, [192, 192, 3])
|
||
return result
|
||
|
||
ds = ds.map(parse, num_parallel_calls=AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<ParallelMapDataset shapes: (192, 192, 3), types: tf.float32>
|
||
|
||
```
|
||
|
||
现在,像之前一样添加标签和进行相同的标准操作:
|
||
|
||
```py
|
||
ds = tf.data.Dataset.zip((ds, label_ds))
|
||
ds = ds.apply(
|
||
tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
|
||
ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
|
||
ds
|
||
```
|
||
|
||
```py
|
||
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int64)>
|
||
|
||
```
|
||
|
||
```py
|
||
timeit(ds)
|
||
```
|
||
|
||
```py
|
||
........................
|
||
231.0 batches: 1.8890972137451172 s
|
||
3912.98020 Images/s
|
||
Total time: 2.7021732330322266s
|
||
|
||
``` |