python - 数据 API : ValueError: `y` argument is not supported when using dataset as input

标签 python tensorflow tensorflow-datasets

我有 45000 张大小为 224*224 的图像,存储为一个 numpy 数组。这个名为 source_arr 的数组的形状为 45000,224,224,它适合内存。

我将此数组分为训练、测试和验证数组,并使用 tf.data API 对它们进行预处理(规范化并将灰度转换为 3 channel RGB)。

我写了一个预处理函数,例如:

def pre_process(x):
     x_norm = (x - mean_Rot_MIP) / Var_Rot_MIP
     # Stacking along the last dimension to avoid having to move channel axis
     x_norm_3ch = tf.stack((x_norm, x_norm, x_norm), axis=-1)
     return x_norm_3ch

X_train_cases_idx.idx 包含来自 source_arr 的图像索引,它们是训练数据的一部分。

我已经从数据集对象中的 source_arr 中读取了相应的训练图像,例如:

X_train = tf.data.Dataset.from_tensor_slices([source_arr[i] for i in X_train_cases_idx.idx])

然后我在训练图像上应用 pre_process 函数,例如 X_train = X_train.map(pre_process)

这是一个多类分类问题,因此我将标签变量 y 转换为 1 个热编码,如下所示:

lb = LabelBinarizer()
y_train = lb.fit_transform(y_train)

X_train和y_train的长度都是36000

我在 RESNET50 上执行 model.fit 操作,如下所示:

H = model.fit(X_train, y_train, batch_size = BS, validation_data=(X_val, y_val), epochs = NUM_EPOCHS, shuffle =False)

我得到一个错误:

ValueError: `y` argument is not supported when using dataset as input.

我知道我需要将 X_train 和 y_train 作为数据集对象中的元组传递。 我该怎么做?

最佳答案

你有 source_arr 和 y_train 作为 numpy 数组;所以你可以这样做:

data_set = tf.data.Dataset.from_tensor_slices(  (source_arr , y_train) )

如果你有 source_arr 和 y_train 作为 tf.dataset:

data_set = tf.data.Dataset.zip( (source_arr , y_train) )

关于python - 数据 API : ValueError: `y` argument is not supported when using dataset as input,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63456492/

相关文章:

python - Pandas 获取加载到内存中的所有数据帧的列表

python - TensorFlow 优化器是否最小化 API 实现的小批量?

python - 我如何构建动态查询elasticsearch dsl python

machine-learning - TensorFlow 学习率衰减 - 如何正确提供衰减的步数?

python - 从 TFRecordDataset 获取数据集作为 numpy 数组

python - 在 python 多处理队列中推进 tensorflow 数据集迭代器

python - 为什么 lambda 函数的参数是这样传递的?

python - Tensorboard 权重直方图仅最后一层可见变化

python - 优化tensorflow数据集api中的洗牌缓冲区大小

python - sess.run() 多个操作 vs 多个 sess.run()