<分区>
我想将数据集迭代器传递给函数,但该函数需要知道数据集的长度。在下面的示例中,我可以将 len(datafiles)
传递给 my_custom_fn()
函数,但我想知道我是否能够从中提取数据集的长度iterator
、batch_x
或 batch_y
类,这样我就不必将其添加为输入。
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
谢谢!
编辑:此解决方案不适用于我的情况:tf.data.Dataset: how to get the dataset size (number of elements in a epoch)?
运行后
tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])
返回
<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>
我确实找到了适合我的解决方案。将 iterator_scope 添加到我的代码中,例如:
with tf.name_scope('iter'):
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
然后从 my_custom_fn
调用:
def my_custom_fn(batch_x, batch_y):
filenames = batch_x.graph.get_operation_by_name(
'iter/InputDataSet/filenames').outputs[0]
n_epoch = sess.run(sess.graph.get_operation_by_name(
'iter/Iterator/count').outputs)[0]
batch_size = sess.run(sess.graph.get_operation_by_name(
'iter/Iterator/batch_size').outputs)[0]
# lots of other stuff
不确定这是否是最好的方法,但它似乎有效。很高兴就此提出任何建议,因为它看起来有点老套。