python - 如何从 tensorflow 数据集中提取没有标签的数据

标签 python tensorflow tensorflow-datasets

我有一个名为 train_ds 的 tf 数据集:

directory = 'Data/dataset_train'

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  directory,
  validation_split=0.2,
  subset="training",
    color_mode='grayscale',
  seed=123,
  image_size=(28, 28),
  batch_size=32)

该数据集由 20000 张“假”图像和 20000 张“真实”图像组成,我想从该 tf 数据集中提取 numpy 形式的 X_train 和 y_train,但我只能使用

y_train = np.concatenate([y for x, y in train_ds], axis=0)

我也尝试过这个,但它似乎没有迭代 20000 张图像:

for images, labels in train_ds.take(-1):  
    X_train = images.numpy()
    y_train = labels.numpy()

我真的很想将图像提取到 X_train,将标签提取到 y_train,但我不知道! 对于我所犯的任何错误,我提前表示歉意,并感谢我可以获得的所有帮助:)

最佳答案

如果您没有对数据集应用进一步的转换,它将是一个 BatchDataset。您可以创建两个列表来迭代数据集。我总共有 2936 张图像。

x_train, y_train = [], []

for images, labels in train_ds:
  x_train.append(images.numpy())
  y_train.append(labels.numpy())

np.array(x_train).shape >> (92,)

它正在生成批处理。您可以使用np.concatenate来连接它们。

x_train = np.concatenate(x_train, axis = 0) 
x_train.shape >> (2936,28,28,3)

或者您可以取消批处理数据集并对其进行迭代:

for images, labels in train_ds.unbatch():
  x_train.append(images.numpy())
  y_train.append(labels.numpy())

x_train = np.array(x_train)
x_train.shape >> (2936,28,28,3)

关于python - 如何从 tensorflow 数据集中提取没有标签的数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67521759/

相关文章:

python - 来自生成器的数据集一次生成多个元素

python - TensorFlow 数据集 .map() 方法不适用于内置 tf.keras.preprocessing.image 函数

python - pip installtensorflow在尝试安装相关packeg :futures时抛出错误

python - 常用文档字符串列表 :types for pycharm

python - 使用分页模块的 Django-Rest-Framework csv 模块

python - 在 openerp 中验证 arch 字段时, View 架构的 XML 无效

python - Tensorflow - 仅模型预测 6 个类别中的 2 个类别

python - “模块”对象没有属性 'SummaryWriter'

python - 使用Rawpy读取tensorflow的map方法内的原始图像文件

python - Python代码的计时执行速度