我想将 tf.Strings
列表传递给 .map(_parse_function)
函数。
def _parse_function(self, img_path):
img_str = tf.read_file(img_path)
img_decode = tf.image.decode_jpeg(img_str, channels=3)
img_decode = tf.divide(tf.cast(img_decode , tf.float32),255)
return img_decode
当tf.data.Dataset
的类型为TensorSliceDataset
时,
dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))
我可以简单地做
dataset_from_slices.map(_parse_function
),有效。
但是,dataset_from_generator = tf.data.Dataset.from_generator(...)
返回一个 Dataset
,它是 FlatMapDataset
类型的实例并且 dataset_from_generator.map(_parse_function)
给出以下错误:
InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]
如果我将第一行更改为:
img_str = tf.read_file(img_path[0])
这也有效,但我只得到第一张图像,这不是我要找的。有什么建议吗?
最佳答案
听起来您的 dataset_from_generator
的元素是批处理的。最简单的补救措施是使用 tf.contrib.data.unbatch() 将它们转换回单个元素:
# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)
# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())
dataset = dataset.map(_parse_function)
关于python - 将 tensorflow tf.data.Dataset FlatMapDataset 转换为 TensorSliceDataset,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49890820/