python - Tensorflow:对 tf.estimator.inputs.numpy_input_fn 函数进行故障排除

标签 python tensorflow

我正在运行来自 text classification 的一些教程代码

我可以运行脚本并且它有效,但是当我尝试逐行运行它以试图了解每个步骤在做什么时,我在这一步有点困惑:

test_input_fn = tf.estimator.inputs.numpy_input_fn(
  x={WORDS_FEATURE: x_test},
  y=y_test,
  num_epochs=1,
  shuffle=False)
classifier.train(input_fn=train_input_fn, steps=100)

我从概念上知道 train_input_fn 正在向训练函数提供数据,但我如何手动调用此 fn 来检查其中的内容?

我跟踪了代码,发现 train_input_fn 函数将数据提供给以下 2 个变量:

features
Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>}

labels
Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32>

当我尝试通过执行 sess.run(features) 来评估 features 变量时,我的终端似乎卡住了并​​停止响应。

检查这些变量内容的正确方法是什么?

谢谢!

最佳答案

基于numpy_input_fn documentation和行为(挂起)我想底层实现取决于队列运行器。队列运行器未启动时会发生挂起。尝试根据 this guide 将您的 session 运行脚本修改为如下内容:

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for step in xrange(1000000):
            if coord.should_stop():
                break
            features_data = sess.run(features)
            print(features_data)

    except Exception, e:
        # Report exceptions to the coordinator.
        coord.request_stop(e)
    finally:
        # Terminate as usual. It is safe to call `coord.request_stop()` twice.
        coord.request_stop()
        coord.join(threads)

或者,我鼓励您查看 tf.data.Dataset 接口(interface)(在 tensorflow 1.3 或更早版本中可能是 tf.contrib.data.Dataset)。您可以获得类似的输入/标签张量,而无需使用 Dataset.from_tensor_slices 的队列。创建稍微复杂一些,但接口(interface)更加灵活,并且实现不使用队列运行器,这意味着 session 运行要简单得多。

import tensorflow as tf
import numpy as np

x_data = np.random.random((100000, 2))
y_data = np.random.random((100000,))

batch_size = 2
buff = 100


def input_fn():
    # possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier
    dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
    dataset = dataset.repeat().shuffle(buff).batch(batch_size)
    x, y = dataset.make_one_shot_iterator().get_next()
    return x, y


x, y = input_fn()
with tf.Session() as sess:
    print(sess.run([x, y]))

关于python - Tensorflow:对 tf.estimator.inputs.numpy_input_fn 函数进行故障排除,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46762932/

相关文章:

java - 将数据从java存储到cassandra

php - Python 到 PHP 的行

linux - 如何在后台运行 Keras、tensorflow 程序并保留所有日志?

model - Tensorflow Slim 恢复模型并预测

python - Python 中的背包(什么?)

python - 数据存储 : 中的属性用户已损坏

python - OSX 上的 Tensorflow 0.10 (CUDA) 在 python 导入时出现段错误

tensorflow - Keras/Tensorflow GPU 使用率低?

tensorflow - TensorFlow 作业是否默认使用多个内核?

python - 如何使用 PyAudio 选择特定的输入设备