python - 在 TensorFlow Dataset API 中访问排队项的数量

标签 python tensorflow

我正在将我的 TensorFlow 代码从旧的队列接口(interface)更改为新的数据集 API。使用旧界面,我可以通过访问图中的原始计数器来监视实际填充的队列大小,例如如下:

queue = tf.train.shuffle_batch(...,  name="training_batch_queue")
queue_size_op = "training_batch_queue/random_shuffle_queue_Size:0"
queue_size = session.run(queue_size_op)

但是,使用新的数据集 API,我似乎无法在图表中找到与队列/数据集相关的任何变量,因此我的旧代码不再有效。有什么方法可以使用新的数据集 API 获取队列中的项目数(例如,在 tf.Dataset.prefetchtf.Dataset.shuffle 队列中)?

监控队列中的项目数量对我来说很重要,因为它告诉我很多关于队列中预处理行为的信息,包括是预处理还是余数(例如神经网络) 是速度瓶颈。

最佳答案

作为解决方法,可以保留一个计数器来指示队列中有多少项目。以下是定义计数器的方法:

 queue_size = tf.get_variable("queue_size", initializer=0,
                              trainable=False, use_resource=True)

然后,在预处理数据时(例如在 dataset.map 函数中),我们可以增加该计数器:

 def pre_processing():
    data_size = ... # compute this (could be just '1')
    queue_size_op = tf.assign_add(queue_size, data_size)  # adding items
    with tf.control_dependencies([queue_size_op]):
        # do the actual pre-processing here

每次我们使用一批数据运行我们的模型时,我们就可以减少计数器:

 def model():
    queue_size_op = tf.assign_add(queue_size, -batch_size)  # removing items
    with tf.control_dependencies([queue_size_op]):
        # define the actual model here

现在,我们需要做的就是在我们的训练循环中运行 queue_size 张量来找出当前队列大小,即此时队列中的项目数:

 current_queue_size = session.run(queue_size)

与旧方法(在数据集 API 之前)相比,它有点不那么优雅,但它可以解决问题。

关于python - 在 TensorFlow Dataset API 中访问排队项的数量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47412762/

相关文章:

python - 我可以在 Tkinter 中注册更改事件的回调吗?

Python:对音乐文件执行 FFT

python - 如何使用 Python 3.8 安装 TensorFlow

python - Tensorflow:无法将 tf.case 与输入参数一起使用

python - 导入错误 : cannot import name 'feature_column_v2' from 'tensorflow.python.tpu' using Object Detection API

python - 在 Tensorflow TPU 上乘以大量矩阵和向量

tensorflow - TensorFlow 中的资格跟踪

python - vim 在函数和类定义下添加自动 sphinx 注释

python - sys.setdefaultencoding ('utf-8' 的危险)

python - 使用 Beautiful Soup 查找下一个出现的标签及其包含的文本