tensorflow - 如何将类添加到现有模型?

标签 tensorflow neural-network object-detection tfrecord

我已经使用 tensorflow 对象检测/SSD mobilenet 训练了一个模型。它工作得很好!

我想给它添加一个类 - 只是为了检测笔或其他东西。

我怎样才能做到这一点?

我已经创建了我的图像集,我只是找不到关于如何将单个类添加到现有模型的任何教程或信息。

谢谢!

最佳答案

您将类添加到现有模型的想法,用 tensorflow 对象检测 api 行话来说,就是在自定义数据集上重新训练自定义对象检测模型 (在这种情况下,您的笔数据集)。

关于如何使用 tensorflow 对象检测 api 构建自定义对象检测器,有很多很好的教程。

比如senddex贴了一个很好的step by step教程here .官方 github repo 页面也包含一些像这样的好教程:bringing in your own dataset ,从某种意义上说,这实际上与从预训练模型中添加或删除类相同。

但同样,我认为上述教程并没有达到 的确切目标。添加类(class) 对于模型,如果您有旧类和新类的数据并重新训练所有这些类,它只会添加新类。由于在您的情况下,您只有新类的数据,因此更正式地将其称为重新训练自定义对象检测模型。

关于tensorflow - 如何将类添加到现有模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56594290/

相关文章:

python - Keras:使用 model.train_on_batch() 和 model.fit() 获得不同的精度。可能是什么原因以及如何解决?

python-3.x - tensorflow (使用 Keras)中出现 'InvalidArgumentError: Incompatible shapes: [10,2] vs. [10]' 的原因是什么?

c++ - OpenCV:如何使用 Haar Classifier Cascade 提高眼睛检测的准确性?

python - 根据输入形状计算是否有差异? (使用 Tensorflow 的 Python 中的 CNN)

python-3.x - 以更少的内存加载模型检查点

python - Keras/theano 中的最大利润损失

python - 通过预处理提高神经网络的准确性

opencv3.0 - detectorMultiScale(a, b, c) 参数含义

python - 如果未设置 tf.stop_gradient 会怎样?

python - 为什么不这样训练 GAN?