python - 如何在 TensorFlow 对象检测 API 中保留类特定权重的同时重置类

标签 python tensorflow object-detection object-detection-api transfer-learning

我目前正在使用 TensorFlow Object Detection API并试图从模型动物园微调预训练的 Faster-RCNN。目前,如果我选择与原始网络中使用的数量不同的类别,它不会从 SecondStageBoxPredictor/ClassPredictor 初始化权重和偏差,因为它现在具有与原始 类预测器。但是,由于我想在网络上训练的所有类都是原始网络经过训练可以识别的类,因此我想保留与我想在 SecondStageBoxPredictor/ClassPredictor 中使用的类相关的权重和偏差 并修剪所有其他值,而不是简单地从头开始初始化这些值(类似于 this function 的行为)。

这是否可能?如果可能,我将如何在 Estimator 中修改该层的结构?

注意This question问了类似的事情,他们的回答是忽略网络输出中不相关的类——然而,在这种情况下,我试图微调网络,我假设这些冗余类的存在会使训练/评估过程复杂化?

最佳答案

如果您想要训练网络的所有类别都是网络经过训练可以识别的类别,您可以简单地使用网络进行检测,不是吗?

但是,如果你有额外的类(class)并且你想进行迁移学习,你可以通过设置从检查点恢复尽可能多的变量:

fine_tune_checkpoint_type: 'detection'
load_all_detection_checkpoint_vars: True

在管道配置文件的 train_config 字段中。

最后,通过查看计算图,可以看出 SecondStageBoxPredictor/ClassPredictor/weights 的形状取决于输出类的数量。 enter image description here

请注意,在tensorflow中你只能在变量级别恢复,如果两个变量有不同的形状,一个不能用一个来初始化另一个。因此,在您的情况下,保留 weights 变量的某些值的想法是不可行的。

关于python - 如何在 TensorFlow 对象检测 API 中保留类特定权重的同时重置类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55651492/

相关文章:

python - 使用 Python OpenCV 检测图像中的所有圆圈(光学标记识别)

tensorflow - 如何在 TF Object Detection 2.0 中分别加载已保存的 Faster R-CNN 的两个阶段?

c++ - 精炼 Haar 检测

python - 为什么我检查井字游戏获胜者的方法不起作用?

python - 用于解析简单英语定义的正则表达式示例

python - 使用tf.nn.dynamic_rnn制作多个隐藏层的LSTM RNN

python - 创建一个内容大于 2GB 的张量原型(prototype)

python - 如何在字典中获取多个最大键值?

python - 自动从 excel 文件发送批量电子邮件

python - NLP AI 逻辑 - 每个序列架构具有多个参数的对话序列