python - 如果未设置 tf.stop_gradient 会怎样?

标签 python tensorflow object-detection tensorflow-model-analysis

我正在阅读 tensorflow 模型的 faster-rcnn 代码。我对 tf.stop_gradient 的使用感到困惑。

考虑以下代码片段:

if self._is_training:
    proposal_boxes = tf.stop_gradient(proposal_boxes)
    if not self._hard_example_miner:
    (groundtruth_boxlists, groundtruth_classes_with_background_list, _,
     groundtruth_weights_list
    ) = self._format_groundtruth_data(true_image_shapes)
    (proposal_boxes, proposal_scores,
     num_proposals) = self._sample_box_classifier_batch(
         proposal_boxes, proposal_scores, num_proposals,
         groundtruth_boxlists, groundtruth_classes_with_background_list,
         groundtruth_weights_list)

更多代码为here .我的问题是:如果未为 proposal_boxes 设置 tf.stop_gradient 会怎样?

最佳答案

这真是个好问题,因为这条简单的tf.stop_gradient 行在训练faster_rcnn 模型时非常关键。这就是为什么在培训期间需要它。

Faster_rcnn 模型是两阶段检测器,损失函数必须满足两个阶段的目标。在 faster_rcnn 中,rpn 损失和 fast_rcnn 损失都需要最小化。

这是论文第 3.2 节的内容

Both RPN and Fast R-CNN, trained independently will modify their convlolutional layers in different ways. We therefore need to develop a technique that allows for sharing convolutional layers between the two networks, rather than learning two separate networks.

然后论文描述了三种训练方案,在原论文中他们采用了第一种方案——Alternating training,即先训练RPN再训练Fast-RCNN。

第二种方案是近似联合训练,实现简单,API采用该方案。 Fast R-CNN 接受来自预测边界框的输入坐标(通过 rpn),因此 Fast R-CNN 损失将具有 w.r.t 边界框坐标的梯度。但在这个训练方案中,这些梯度被忽略,这正是使用tf.stop_gradient的原因。论文报道,这种训练方案将减少训练时间25-50%。

第三种方案是非近似联合训练,所以不需要tf.stop_gradient。该论文报告说,拥有一个可微分 w.r.t 框坐标的 RoI 池化层是一个非常重要的问题。

但是为什么这些梯度被忽略了呢?

事实证明,RoI 池化层是完全可微的,但支持方案二的主要原因是方案三会导致它在训练早期不稳定。

API 的一位作者给出了非常好的答案 here

一些 further reading关于近似联合训练。

关于python - 如果未设置 tf.stop_gradient 会怎样?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56059078/

相关文章:

python - 这个 tensorflow 安装有什么问题?我已经安装了GPU版本的tensorflow

python - 谷歌云平台,机器学习引擎, "No module named absl"

python - 有没有办法将变量传递给 Jinja2 parent ?

用于后缀的 Python 过滤器

tensorflow - Tensorboard,只显示最新的 tfevents

python - 将 Tensorflow 输入管道与 skflow/tf learn 结合使用

android - 使用OpenCV在Android上进行多对象检测

python - Tensorflow 对象检测 API 中的数据增强

python - 如何将 HTMLUnit 驱动程序与 Python 中的 Selenium 一起使用?

python - 将变量保持在范围内同时捕获异常的优雅/Pythonic 方法是什么?