python - 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?

标签 python tensorflow

我想在一组数据上重复训练我的 tensorflow 图,我想 tf.estimator.inputs.numpy_input_fn可能就是我正在寻找的。我发现批量大小、重复、纪元和迭代器之间的区别非常令人困惑,因此我开始尝试检查数据集的内容,试图找出实际发生的情况。然而,每当我尝试这样做时,我的程序就会挂起。

这是我想出的重现此问题的最小测试用例:

import tensorflow as tf
import numpy

class TestMock(tf.test.TestCase):
    def test(self):
        inputs = numpy.array(range(10))
        targets = numpy.array(range(10,20))

        input_fn = tf.estimator.inputs.numpy_input_fn(
            x=inputs,
            y=targets,
            batch_size=1,
            num_epochs=2,
            shuffle=False)

        print input_fn()
        with self.test_session() as sess:
            # sess.run(input_fn()[0]) # it'll hang if I run this
            pass

if __name__ == '__main__':
    tf.test.main()

该程序输出

(<tf.Tensor 'fifo_queue_DequeueUpTo:1' shape=(?,) dtype=int64>, <tf.Tensor 'fifo_queue_DequeueUpTo:2' shape=(?,) dtype=int64>)

这看起来很合理,但是当我尝试运行 sess.run 行时,我的程序就会卡住,我必须终止该进程。我在这里做错了什么?

我想要做的是确保我输入到流程中的数据实际上是我认为的那样,但我认为如果没有检查数据的能力我就无法做到这一点。

最佳答案

从上面的打印语句我们可以推断出input_fn返回queue ops,我们需要使用start_queue_runners and Coordinator来运行它们:

 features_op, labels_op = input_fn()
 with tf.Session() as sess:
     # initialise and start the queues.
     sess.run(tf.local_variables_initializer())

     coordinator = tf.train.Coordinator()
     _ = tf.train.start_queue_runners(coord=coordinator)

    print(sess.run([features_op, labels_op]))

    #[array([0]), array([10])]

关于python - 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50789693/

相关文章:

python - Pandas 数据框如何用多个替换单列

python - Django i18n 不工作

python - 按 pandas 将数据与行数分组

python - Keras 自定义损失函数,包含来自完整输入数据集的样本

collections - 在tensorflow中,你可以定义自己的集合名称吗?

tensorflow - 如何在使用 Keras 或 Tensorflow 训练深度神经网络期间添加更多数据

python - 从检查点恢复时,如何更改参数的数据类型?

python - 需要帮助创建 GAE 数据存储加载器类?

python - 如何控制Scapy发送数据包的速度?

python - 查找数据框中每组另一个常见单元格中最常见的单元格