我正在运行来自 text classification 的一些教程代码
我可以运行脚本并且它有效,但是当我尝试逐行运行它以试图了解每个步骤在做什么时,我在这一步有点困惑:
test_input_fn = tf.estimator.inputs.numpy_input_fn(
x={WORDS_FEATURE: x_test},
y=y_test,
num_epochs=1,
shuffle=False)
classifier.train(input_fn=train_input_fn, steps=100)
我从概念上知道 train_input_fn 正在向训练函数提供数据,但我如何手动调用此 fn 来检查其中的内容?
我跟踪了代码,发现 train_input_fn 函数将数据提供给以下 2 个变量:
features
Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>}
labels
Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32>
当我尝试通过执行 sess.run(features) 来评估 features 变量时,我的终端似乎卡住了并停止响应。
检查这些变量内容的正确方法是什么?
谢谢!
最佳答案
基于numpy_input_fn
documentation和行为(挂起)我想底层实现取决于队列运行器。队列运行器未启动时会发生挂起。尝试根据 this guide 将您的 session 运行脚本修改为如下内容:
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
for step in xrange(1000000):
if coord.should_stop():
break
features_data = sess.run(features)
print(features_data)
except Exception, e:
# Report exceptions to the coordinator.
coord.request_stop(e)
finally:
# Terminate as usual. It is safe to call `coord.request_stop()` twice.
coord.request_stop()
coord.join(threads)
或者,我鼓励您查看 tf.data.Dataset
接口(interface)(在 tensorflow 1.3 或更早版本中可能是 tf.contrib.data.Dataset
)。您可以获得类似的输入/标签张量,而无需使用 Dataset.from_tensor_slices
的队列。创建稍微复杂一些,但接口(interface)更加灵活,并且实现不使用队列运行器,这意味着 session 运行要简单得多。
import tensorflow as tf
import numpy as np
x_data = np.random.random((100000, 2))
y_data = np.random.random((100000,))
batch_size = 2
buff = 100
def input_fn():
# possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
dataset = dataset.repeat().shuffle(buff).batch(batch_size)
x, y = dataset.make_one_shot_iterator().get_next()
return x, y
x, y = input_fn()
with tf.Session() as sess:
print(sess.run([x, y]))
关于python - Tensorflow:对 tf.estimator.inputs.numpy_input_fn 函数进行故障排除,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46762932/