python - 确定 tf.data.Dataset Tensorflow 中的记录数

标签 python tensorflow machine-learning deep-learning

<分区>

我想将数据集迭代器传递给函数,但该函数需要知道数据集的长度。在下面的示例中,我可以将 len(datafiles) 传递给 my_custom_fn() 函数,但我想知道我是否能够从中提取数据集的长度iteratorbatch_xbatch_y 类,这样我就不必将其添加为输入。

dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

谢谢!

编辑:此解决方案不适用于我的情况:tf.data.Dataset: how to get the dataset size (number of elements in a epoch)?

运行后

tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])

返回

<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>

我确实找到了适合我的解决方案。将 iterator_scope 添加到我的代码中,例如:

with tf.name_scope('iter'):
    dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
    iterator = dataset.make_initializable_iterator()
    sess.run(iterator.initializer)
    [batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

然后从 my_custom_fn 调用:

def my_custom_fn(batch_x, batch_y):
    filenames = batch_x.graph.get_operation_by_name(
                  'iter/InputDataSet/filenames').outputs[0]
    n_epoch = sess.run(sess.graph.get_operation_by_name(
                  'iter/Iterator/count').outputs)[0]
    batch_size = sess.run(sess.graph.get_operation_by_name(
                  'iter/Iterator/batch_size').outputs)[0]
    # lots of other stuff

不确定这是否是最好的方法,但它似乎有效。很高兴就此提出任何建议,因为它看起来有点老套。

最佳答案

iterator 的长度在您遍历它之前是未知的。您可以显式地将 len(datafiles) 传递给该函数,但如果您坚持数据的持久性,您可以简单地使该函数成为一个实例方法并将数据集的长度存储在对象中my_custom_fn 是一种方法。

不幸的是,作为一个迭代器,它不存储任何东西,它动态生成数据。但是,正如在 TensorFlow 的源代码中所发现的那样,有一个“私有(private)”变量 _batch_size 用于存储批量大小。您可以在此处查看源代码:TensorFlow source .

关于python - 确定 tf.data.Dataset Tensorflow 中的记录数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52264469/

相关文章:

python - 用soup.select在美汤中选二胎?

javascript - AJAX 请求无法读取更新的 session /全局变量

python - 回归运动损失巨大!使用批量输入进行测试?

python - 如何在 python 之外使用 Vowpal Wabbit 模型

python - 我的 python 包 '' 深度匹配器“安装有问题

machine-learning - 了解 LibSVM 中 SVM 参数的好资源

python - 扭曲的 adbapi : runInteraction last_insert_id()

python - 引发运行时错误 ('FPDF error: ' +msg) 运行时错误 : FPDF error: Unsupported image type: chapter_1_romance_dawn

python - 如何恢复logit

tensorflow - API.ai/RASA NLU 可以与 Tensorflow 集成来制作聊天机器人吗