我正在使用Tensorflow Object Detection API使用迁移学习训练对象检测模型。具体来说,我正在使用ssd_mobilenet_v1_fpn_coco from the model zoo ,并使用sample pipeline provided ,当然用我的训练和评估 tfrecords 和标签的实际链接替换了占位符。
我能够使用上述管道在大约 5000 张图像(以及相应的边界框)上成功训练模型(如果相关的话,我主要在 TPU 上使用 Google 的 ML 引擎)。
现在,我准备了额外的约 2000 张图像,并希望继续使用这些新图像训练我的模型,而无需从头开始(训练初始模型花费了约 6 小时的 TPU 时间)。我怎样才能做到这一点?
最佳答案
您有两个选项,在这两个选项中,您都需要更改新数据集的 train_input_reader
的 input_path
:
- 在训练配置中指定要微调的检查点时,请指定训练模型的检查点
train_config{
fine_tune_checkpoint: <path_to_your_checkpoint>
fine_tune_checkpoint_type: "detection"
load_all_detection_checkpoint_vars: true
}
- 只需继续使用与先前模型相同的
model_dir
的相同配置(train_input_reader
除外)即可。这样,API 将创建一个图表,并检查model_dir
中是否已存在检查点并且适合该图表。如果是这样 - 它将恢复它并继续训练它。
编辑:fine_tune_checkpoint_type之前被错误地设置为true,而一般情况下它应该是“检测”或“分类”,在这种特定情况下应该是“检测”。感谢克里什的关注。
关于tensorflow - 如何使用 Tensorflow 对象检测 API 继续训练对象检测模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53104300/