我有一个形状为 500,36,24,72 的 NumPy 数组。现在我想使用 tf.data 为问题创建数据管道。对于每次迭代,仅需要数组的子集,例如,首先在 [500,x:y,24,72] 上训练模型,其中仅采用第二维的子集。
ds1 = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(data))
对上述数据集应用过滤器似乎不起作用
ds2 = ds1.filter(lambda x: x[1:3][:][:])
最佳答案
import numpy as np
import tensorflow as tf
data = np.random.random((500,36,24,72))
ds1 = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(data)))
ds2 = ds1.map(lambda x: x[1:3, ...])
关于python - 从 tf.data 中仅提取 numpy 数组的一部分,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73685542/