tensorflow - 如何使用 Tensorflow 对象检测 API 继续训练对象检测模型?

标签 tensorflow machine-learning google-cloud-ml object-detection-api

我正在使用Tensorflow Object Detection API使用迁移学习训练对象检测模型。具体来说,我正在使用ssd_mobilenet_v1_fpn_coco from the model zoo ,并使用sample pipeline provided ,当然用我的训练和评估 tfrecords 和标签的实际链接替换了占位符。

我能够使用上述管道在大约 5000 张图像(以及相应的边界框)上成功训练模型(如果相关的话,我主要在 TPU 上使用 Google 的 ML 引擎)。

现在,我准备了额外的约 2000 张图像,并希望继续使用这些新图像训练我的模型,而无需从头开始(训练初始模型花费了约 6 小时的 TPU 时间)。我怎样才能做到这一点?

最佳答案

您有两个选项,在这两个选项中,您都需要更改新数据集的 train_input_readerinput_path:

  1. 在训练配置中指定要微调的检查点时,请指定训练模型的检查点
train_config{
    fine_tune_checkpoint: <path_to_your_checkpoint>
    fine_tune_checkpoint_type: "detection"
    load_all_detection_checkpoint_vars: true
}
  • 只需继续使用与先前模型相同的 model_dir 的相同配置(train_input_reader 除外)即可。这样,API 将创建一个图表,并检查 model_dir 中是否已存在检查点并且适合该图表。如果是这样 - 它将恢复它并继续训练它。
  • 编辑:fine_tune_checkpoint_type之前被错误地设置为true,而一般情况下它应该是“检测”或“分类”,在这种特定情况下应该是“检测”。感谢克里什的关注。

    关于tensorflow - 如何使用 Tensorflow 对象检测 API 继续训练对象检测模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53104300/

    相关文章:

    python - 我无法导入 tensorflow-gpu

    tensorflow - 如何使用Tensorflow张量设置Keras层的输入?

    python - tensorflow-gpu 中的 "' CXXABI_1.3.8 ' not found"- 从源安装

    machine-learning - 遗传算法如何用数值数据演化出解决方案?

    google-cloud-platform - GCS/GCML 服务因错误错误而被阻止(超出管理 CRUD 配额)

    android - 来自 Keras 的卡住模型在恢复后无法预测

    machine-learning - 决策树学习和杂质

    python - Keras 模型根本不学习

    machine-learning - 导入错误: cannot import name JSONClient

    google-cloud-platform - 从 Google Cloud Datalab 笔记本下载保存的模型