Tensorflow 2.0 Building a Dataset and Data Preprocessing

Dataset Object

TF Datasets的官网API文档

本文Code

最基础的建立tf.data.Dataset的方法是使用tf.data.Dataset.from_tensor_slices(),适用于数据量较小(能够整个装进内存)的情况。

备注: 当提供多个张量作为输入时,张量的第 0 维大小必须相同,且必须将多个张量作为元组(Tuple,即使用 Python 中的小括号)拼接并作为输入。

Tensorflow提供了开箱即用的tf.data.Datasets数据集集合,官网Tensorflow Datasets导入方式如下:

import tensorflow_datasets as tfds
dataset = tfds.load('mnist', split=tfds.Split.TRAIN)

Data Preprocessing

tf.data.Dataset类提供了多种数据预处理的方法,最常用API如:

  • dataset.map(func):对数据集中的每个元素应用函数func,通常结合tf.io进行读写和解码文件,tf.image进行图像处理,其中的参数num_parallel_calls代表使用CPU的多个核心进行并行化数据预处理
  • dataset.shuffle(buffer_size):将数据集打乱,设置一个缓冲区,将原数据集中抽取数据到缓存区,从缓存区采样数据,采样后的数据用数据集中的后续数据替换
  • dataset.batch(batch_size):将数据集分成批次,对每batch_size个元素,使用tf.stack在第0维进行合并,成为一个元素
  • dataset.repeat(num): 重复数据集num次
  • dataset.reduce(initial_state, reduce_func): 对元素进行迭代操作
  • dataset.take(count): 去特定数量的元素

可以使用tf.data的并行化策略提高训练流程效率,下面是常规的训练流程,在准备数据时,GPU只能空载

使用Dataset.prefetch()方法进行数据预加载后的训练流程,在 GPU 进行训练的同时 CPU 进行数据预加载,提高了训练效率。

通过设置Dataset.map()num_parallel_calls参数实现数据转换的并行化,可以手动设置,也可以设置为tf.data.experimental.AUTOTUNE让TF自动选择和合适的参数。上部分是未并行化的图示,下部分是 2 核并行的图示。

Data Fetching and Using

tf.data.Dataset是一个python的可迭代对象,因此可以使用for循环迭代获取数据:

dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
for a, b, c, ... in dataset:
    # 对张量a, b, c等进行操作,例如送入模型进行训练

也可以使用iter()显式创建一个python迭代器并且使用next()获取下一个元素。

dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
it = iter(dataset)
a_0, b_0, c_0, ... = next(it)
a_1, b_1, c_1, ... = next(it)

Keras 支持使用 tf.data.Dataset 直接作为输入。

model.fit(mnist_dataset, epochs=num_epochs)

由于已经通过Dataset.batch()方法划分了数据集的批次,所以这里无需提供批次的大小。

TFRecord: Dataset Format

TFRecord是TensorFlow中的数据集存储格式。当我们将数据集整理成TFRecord格式后,TensorFlow就可以高效地读取和处理这些数据集,从而帮助我们更高效地进行大规模的模型训练。

TFRecord 可以理解为一系列序列化的tf.train.Example元素所组成的列表文件,而每一个tf.train.Example又由若干个tf.train.Feature的字典组成。形式如下:

# dataset.tfrecords
[
    {   # example 1 (tf.train.Example)
        'feature_1': tf.train.Feature,
        ...
        'feature_k': tf.train.Feature
    },
    ...
    {   # example N (tf.train.Example)
        'feature_1': tf.train.Feature,
        ...
        'feature_k': tf.train.Feature
    }
]

保存TFRecord文件,详见Code

读取TFRecord文件,详见Code

Reference

  1. TensorFlow-2.x-Tutorials
  2. 简单粗暴Tensorflow 2.0

Note: Cover Picture