python - tensorflow : can not create dataset with corresponding label using Dataset API

标签 python tensorflow dataset

我的数据集包含特征标签,例如。

features, labels = (np.random.sample((5,2)), np.random.sample((5,1))) 

这意味着该数据集中有 5 个数据元素(有 5 行,每行都是 2 维特征和 1 维标签)。

我使用 tf.data.Dataset 使用以下代码创建数据集:

import tensorflow as tf
import numpy as np
features, labels = (np.random.sample((5,2)), np.random.sample((5,1))) 
print("feature : \n", features)
print("labels : \n", labels)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
iter = dataset.make_one_shot_iterator()            
x, y = iter.get_next()                                       
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   
    print("element:\n", sess.run(x), sess.run(y))

我使用TF1.5,Windows 10。然后我得到结果:

feature :
 [[0.10261779 0.28041519]  # feature0
 [0.91091857 0.95644642]   # feature1
 [0.77542043 0.49631646]   # ...
 [0.33241678 0.28630983]
 [0.39095336 0.76686785]]
labels :
 [[0.54097027]             # label0
 [0.99022349]              # label1
 [0.87510303]              # ...
 [0.07331254]
 [0.10868335]]
element:
 [0.10261779 0.28041519] [0.99022349]

当我创建数据集时,我希望 feature0 [0.10261779 0.28041519] 与 label0 [0.54097027] 相对应。但使用代码,feature0 [0.10261779 0.28041519] 与 label1 [0.99022349] 相对应。顺序错误。我不知道 get_next 实际上是如何工作的。

我想知道是否有任何方法可以使用tensorflow Dataset API按顺序输出特征和标签。

谢谢

最佳答案

问题是,通过分别运行x运行y,您将迭代器推进两次。也就是说:当调用 sess.run(x) 时,返回 features 的第一个元素,并且迭代器前进。然后调用 sess.run(y) 将返回 labels第二个元素,因为 x >y 基于相同的迭代器。如果您再次调用 sess.run(x),它应该返回 features第三元素,依此类推。

我建议您像这样重写代码,例如:

...
next_batch_op = iter.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feature_batch, label_batch = sess.run(next_batch_op)
    print("element:\n", feature_batch, label_batch)

这只会运行迭代器一次并让您访问相应的功能/标签。

作为替代方案,我刚刚尝试了以下方法,它似乎有效:

...
x, y = iter.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("element:\n", sess.run([x, y]))

与您的代码的区别在于,我们在单个 run 调用中一起运行 xy。不过我发现第一个解决方案更清晰。

关于python - tensorflow : can not create dataset with corresponding label using Dataset API,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51824699/

相关文章:

python - 如何检查小写字母是否存在?

python - 具有重叠区域的卫星图像的插值

python - 在 tensorflow 模型中关闭softmax

python - 单 GPU 上的 Tensorflow 2.0 训练模型

tensorflow - 暗网 YOLO 图像大小

python - 使用大型 (+15 gb) CSV 数据集和 Pandas/XGBoost

Python celery socket.error : [Errno 61] Connection refused

c# - 从包含多个表的数据集中找到最小值和最大值

php - 从 PHP 将数据集传递给 JavaScript 验证

machine-learning - 如何标准化数据以输入位于训练数据范围之外的神经网络?