python - 使用 TensorFlow Dataset API 的纪元计数器

标签 python tensorflow

我正在将我的 TensorFlow 代码从旧的队列接口(interface)更改为新的 Dataset API .在我的旧代码中,每次在队列中访问和处理新的输入张量时,我都会通过递增 tf.Variable 来跟踪纪元计数。我想使用新的 Dataset API 计算这个纪元,但我在让它工作时遇到了一些麻烦。

由于我在预处理阶段生成了可变数量的数据项,因此在训练循环中递增 (Python) 计数器并不是一件简单的事情 - 我需要根据队列或数据集的输入。

我模仿了以前使用旧队列系统的方式,这是我最终为数据集 API 得到的(简化示例):

with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while

但是,这是行不通的。它崩溃并显示一条神秘的错误消息:

 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
 are not compatible with expected types ([tf.float32_ref, tf.float32])

仔细检查表明 epoch_counter 变量可能根本无法在 pre_processing_func 中访问。它可能存在于不同的图表中吗?

知道如何修正上面的例子吗?或者如何通过其他方式获取纪元计数器(带小数点,例如 0.4 或 2.9)?

最佳答案

TL;DR:将 epoch_counter 的定义替换为以下内容:

epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)

tf.data.Dataset 转换中使用 TensorFlow 变量有一些限制。原则限制是所有变量都必须是“资源变量”,而不是旧的“引用变量”;不幸的是,tf.Variable 仍然出于向后兼容性的原因创建“引用变量”。

一般来说,如果可以避免,我不建议在 tf.data 管道中使用变量。例如,您可以使用 Dataset.range() 定义纪元计数器,然后执行如下操作:

epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat()))

上面的代码片段将纪元计数器作为第二个组件附加到每个值。

关于python - 使用 TensorFlow Dataset API 的纪元计数器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47410778/

相关文章:

python-3.x - 在Windows7上构建tensorflow

python - 使用 xml.etree.ElementTree 更改 xml 元素文本

python - 从 Python 源代码中提取注释

python - 根据从另一个数组中删除的项目从 Python 数组中删除项目

python - 错误 : file. whl 在此平台上不受支持的轮子

python - Tensorflow:计算梯度子张量

python - 如何添加验证特征/规则?

python - "Allocating size to..."在 Gtk.ScrolledWindow 中使用 Gtk.TreeView 时出现 GTK 警告

tensorflow - 如何直接写入模仿 scalar_summary 的摘要?

python - Tensorflow,其中(索引) 'and' 条件