python - 将 tensorflow tf.data.Dataset FlatMapDataset 转换为 TensorSliceDataset

标签 python tensorflow vectorization tensorflow-datasets

我想将 tf.Strings 列表传递给 .map(_parse_function) 函数。

 def _parse_function(self, img_path):
        img_str = tf.read_file(img_path)
        img_decode = tf.image.decode_jpeg(img_str, channels=3)
        img_decode = tf.divide(tf.cast(img_decode , tf.float32),255)
        return img_decode

tf.data.Dataset的类型为TensorSliceDataset时,

dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))

我可以简单地做 dataset_from_slices.map(_parse_function),有效。

但是,dataset_from_generator = tf.data.Dataset.from_generator(...) 返回一个 Dataset,它是 FlatMapDataset 类型的实例并且 dataset_from_generator.map(_parse_function) 给出以下错误:

InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]

如果我将第一行更改为:

img_str = tf.read_file(img_path[0])

这也有效,但我只得到第一张图像,这不是我要找的。有什么建议吗?

最佳答案

听起来您的 dataset_from_generator 的元素是批处理的。最简单的补救措施是使用 tf.contrib.data.unbatch() 将它们转换回单个元素:

# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)

# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())

dataset = dataset.map(_parse_function)

关于python - 将 tensorflow tf.data.Dataset FlatMapDataset 转换为 TensorSliceDataset,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49890820/

相关文章:

python-3.x - Tensorflow 值错误 : Cannot evaluate tensor using eval

python - 在Python Itertools中获取一组迭代次数中最大值的索引

python - 如何使用 python 3 从 back4app 检索对象?

tensorflow - 在 Keras 中创建 VAE 时,调用方法未实现运行时错误。模型子类化

matlab - 在 MATLAB 中置换矩阵的列

python - 加快滑动窗口平均计算

R - apply 系列中是否有任何函数可以为矩阵的每个条目计算 FUN?

python - Latex 中的\def 错误

python - 使用 Tkinter 显示一系列 PIL 图像

python - 如何在 Keras 中的预训练 InceptionResNetV2 模型的不同层中找到激活的形状 - Tensorflow 2.0