我正在尝试读取用于训练的自定义映射数据集。但是在我使用 py_function 映射数据集之后,我得到了未知的形状,例如:
def process_path(file_path):
label = get_label(file_path)
img = tf.io.read_file(file_path)
img = decode_img(img)
print('image shape:', img.shape) #this print correctly: image shape: (180, 180, 3)
print('label shape:', label.shape) #this print correctly: label shape: ()
return img, label
train_ds = train_ds.map(lambda x: tf.py_function(process_path, [x], (tf.float32, tf.int32)))
print(train_ds)
# this print unknown shape <PrefetchDataset shapes: (<unknown>, <unknown>), types: (tf.float32, tf.int32)>
这将使 model.fit() 失败,所以我想将数据集 reshape 为正确的形状,例如:
<BatchDataset shapes: ((None, 180, 180, 3), (None,)), types: (tf.float32, tf.int32)>
使用:
train_ds = tf.reshape(train_ds, ((None, 180, 180, 3), (None,)))
但这会报错:
ValueError: Attempt to convert a value (<MapDataset shapes: (<unknown>, <unknown>), types: (tf.float32, tf.int32)>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.MapDataset'>) to a Tensor.
如何在此步骤中正确分配(图像、标签)形状?
最佳答案
这里不需要py_function
。假设您有一个名为 /dogs
的文件夹,其中充满了 jpg
。您可以使用这两个小函数来加载和解码。
如果文件名(例如,'dogs\\dog1.jpg'
)在文件夹 dogs 中,第一个返回 1
和 0
否则。
第二个函数也接受一个文件名并将其转换为介于 0 和 1 之间的 float 。然后,它还调整图片的大小。
如果有任何不清楚的地方,请告诉我。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from glob2 import glob
os.chdir('c:/users/nicol/pictures')
files = glob('*/*jpg')
def get_label(file_path):
split = tf.strings.split(file_path, sep=os.sep)[0]
equal = tf.equal(split, 'dogs')
cast = tf.cast(equal, tf.int32)
return cast
def process_path(file_path):
label = get_label(file_path)
img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, size=(180, 180))
return img, label
train_ds = tf.data.Dataset.from_tensor_slices(files).map(process_path)
next(iter(train_ds))
(<tf.Tensor: shape=(180, 180, 3), dtype=float32, numpy=
array([[[1.41176477e-01, 9.41176564e-02, 1.33333340e-01],
[1.41176477e-01, 9.41176564e-02, 1.33333340e-01],
[1.41176477e-01, 9.41176564e-02, 1.33333340e-01],
...,
[2.63300300e-01, 2.76176542e-01, 4.67582583e-01],
[2.46176332e-01, 2.59706050e-01, 4.50785339e-01],
[2.54726082e-01, 2.68909693e-01, 4.59662050e-01]]], dtype=float32)>,
<tf.Tensor: shape=(), dtype=int32, numpy=1>)
get_label
应该返回一个整数,如果不是的话。
关于python - 如何在 py_function 之后 reshape (图像,标签)数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63749852/