tensorflow - 如何使用自定义函数在 TF 2 中使用 tf.data.Dataset.interleave()?

标签 tensorflow tensorflow2.0 tensorflow-datasets

我正在使用 TF 2.2 并尝试使用 tf.data 创建管道。

以下工作正常:

def load_image(filePath, label):

    print('Loading File: {}' + filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations = False)

    return image, label

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.map(load_image, num_parallel_calls=AUTOTUNE)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

我想将 load_image()Dataset.interleave() 一起使用。然后我尝试了:

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

但我收到以下错误:

Exception has occurred: TypeError
`map_func` must return a `Dataset` object. Got <class 'tuple'>
  File "/data/dev/train_daninhas.py", line 44, in <module>
    trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

如何调整我的代码,使 Dataset.interleave()load_image() 一起并行读取图像?

最佳答案

正如错误提示的那样,您需要修改 load_image 以便它返回一个 Dataset 对象,我已经展示了一个包含两个图像的示例,说明如何去做它在 tensorflow 2.2.0 中:

import tensorflow as tf
filenames = ["./img1.jpg", "./img2.jpg"]
labels = ["A", "B"]

def load_image(filePath, label):
    print('Loading File: {}' + filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations = False)
    return tf.data.Dataset.from_tensors((image, label))

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.interleave(lambda x, y: load_image(x, y), cycle_length=4)

for i in dataset.as_numpy_iterator():
    image = i[0]
    label = i[1]
    print(image.shape)
    print(label.decode())

# (275, 183, 3)
# A
# (275, 183, 3)
# B

关于tensorflow - 如何使用自定义函数在 TF 2 中使用 tf.data.Dataset.interleave()?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62105896/

相关文章:

python - 为什么要在神经网络中添加零偏差?

time-series - 为什么要采用 HuggingFace 的第一个隐藏状态进行序列分类(DistilBertForSequenceClassification)

python - Tensorflow 2.0 list_physical_devices 未检测到我的 GPU

python - 属性错误 : module 'keras.backend' has no attribute 'image_dim_ordering'

tensorflow 重映射函数

python-3.x - 如何在 tensorflow 2.0b 中检查/释放 GPU 内存?

python - 如何在 tensorflow 数据集中加载 numpy 数组

python - tensorflow 估计器精度和损失为零

python - 使用多个 Excel 工作表加速 pandas 迭代

python - Firebase Tensorflow Lite 分类模型未在 Swift 应用程序中提供正确的输出