python - 当我使用数据集,连接后,dataset.map仅作用于原始数据

标签 python python-3.x tensorflow tensorflow-datasets

正如标题所示,我连接两个数据集并使用映射函数来更改值的位置和重新缩放值。在我使用map之前,所有张量的形状都是 匹配,但是使用map函数后,并使用for循环迭代数据集打印索引,迭代的断点在两个数据集的连接处。

我在使用 GPU 的 Colab 中遇到了这个问题, 并使用Python 3.6,tensorflow-gpu 2.0.0b1

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, (tf.float32, tf.float32),((7,), (48,48,1)))
dataset = dataset.concatenate(dataset_crop)
dataset = dataset.map(lambda label, img_raw: (tf.cast(img_raw, tf.float32)/float(255), label))
for i,(label, img) in enumerate(dataset):
  print(i)

顺便说一句,连接之前数据集的总行数为 19984
连接到底是什么鬼..

...
19982
19983
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-26-36305ee0e8ef> in <module>()
----> 1 for i,(label, img) in enumerate(dataset):
      2   print(i)

4 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: ValueError: Tensor's shape (7,) is not compatible with supplied shape [48, 48, 1]
Traceback (most recent call last):

  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/script_ops.py", line 209, in __call__
    ret = func(*args)

  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 525, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "<ipython-input-25-196a9ac04fc0>", line 5, in img_resize_and_crop_genr
    img.set_shape([side_len, side_len,1])

  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 981, in set_shape
    (self.shape, shape))

ValueError: Tensor's shape (7,) is not compatible with supplied shape [48, 48, 1]


     [[{{node PyFunc}}]] [Op:IteratorGetNextSync]

最佳答案

问题出在您的 from_generator 函数中。当您传递 output_shapes 参数时,会进行严格检查以查看输出形状是否与生成的形状完全相同。在您的情况下,您会收到一个 ValueError ,表明它期望 (48, 48, 1) 但已生成 (7,) 形状。

使用以下代码可以生成类似的错误:

dataset = tf.data.Dataset.from_tensor_slices((np.zeros(19984, dtype=np.float32), np.ones(19984, dtype=np.float32)))

def img_resize_and_crop_genr():
    yield np.zeros((7,)), np.ones((48, 48, 1))

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, (tf.float32, tf.float32),((48,48,1), (7,)))
dataset = dataset.concatenate(dataset_crop)
dataset = dataset.map(lambda label, img_raw: (tf.cast(img_raw, tf.float32)/float(255), label))
for i,(label, img) in enumerate(dataset):
  print(i)

输出:

ValueError: `generator` yielded an element of shape (7,) where an element of shape (48, 48, 1) was expected.

我相信您已经互换了您的output_shapes。如果是这种情况,您可以进行更正:

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, 
                                              (tf.float32, tf.float32),((7,), (48,48,1)))

此外,output_shapes 是一个可选参数。您可以通过不传递参数来避免整个问题,如下所示:

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, 
                                              (tf.float32, tf.float32))

关于python - 当我使用数据集,连接后,dataset.map仅作用于原始数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58013427/

相关文章:

python-3.x - 运行 Jupyter notebook 时获取 RuntimeWarning 并且永远不会连接到内核

python - TensorFlow:使用不同的损失函数恢复训练

tensorflow - 如何使用tensorflow实现反卷积?

tensorflow - 如何用索引替换张量中的值?

python - 为什么我的用户定义的异常没有得到正确处理?

python - 如何将 TF-IDF 分数组合起来相当于连接两个字符串

python - 删除 python 中的换行符?

python - 如何使用屏幕上的键盘将文本插入到 QLineEdit

python - 如何在 Python 中使用 subprocess.check_output()?

python - Pandas - 使用 .isnull()、notnull()、dropna() 删除缺少数据的行不起作用