tensorflow - Mask RCNN 仅 1 类

标签 tensorflow image-processing keras computer-vision object-detection

我希望只使用一个类,人(连同 BG,背景)来进行 Mask RCNN 对象检测。我正在使用此链接:https://github.com/matterport/Mask_RCNN运行掩码 rcnn。有没有一种特定的方法来完成这个(编辑特定文件,创建一个额外的 python 文件,或者只是通过过滤来自 class_names 数组的选择)?任何方向或解决方案将不胜感激。谢谢

最佳答案

我已经为绵羊训练了相同的代码库。你必须做两件事:

  1. 将训练和推理类数更改为 1 + 1(bg 和 person):

     class SheepsConfig(Config):
    
         NAME = "sheeps"
         NUM_CLASSES = 1 + 1 # background + sheep
    
     config = SheepsConfig()  # Don't forget to use this config while creating your model
     config.display()
    
  2. 您需要创建用于训练的数据集。您可以按如下方式使用 coco:

     import coco
     from pycocotools.coco import COCO
    
     ct = COCO("/YourPathToCocoDataset/annotations/instances_train2014.json")
     ct.getCatIds(['sheep']) 
     # Sheep class' id is 20. You should run for person and use that id
    
     COCO_DIR = "/YourPathToCocoDataset/"
     # This path has train2014, annotations and val2014 files in it
    
     # Training dataset
     dataset_train = coco.CocoDataset()
     dataset_train.load_coco(COCO_DIR, "train", class_ids=[20])
     dataset_train.prepare()
    
     # Validation dataset
     dataset_val = coco.CocoDataset()
     dataset_val.load_coco(COCO_DIR, "val", class_ids=[20])
     dataset_val.prepare()
    

然后简单地创建您的模型:

# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)
model.load_weights(COCO_MODEL_PATH, by_name=True, exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"])
# This COCO_MODEL_PATH is the path to the mask_rcnn_coco.h5 file in this repo

然后你可以用这段代码训练它:

model.train(dataset_train, dataset_val,
        learning_rate=config.LEARNING_RATE, 
        epochs=100, 
        layers='heads')#You can also use 'all' to train all network.

不要忘记使用 tensorflow 1.x 和 keras 2.1.0 :) 我可以使用这些版本进行训练。

关于tensorflow - Mask RCNN 仅 1 类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65810714/

相关文章:

python - 如何获取 Tensorflow 中的所有集合?

python - 在 tensorflow 2.0 beta 中从 tf.data.Dataset 检索下一个元素

python - 致命标志解析错误 :Unknown command line flag 'logtostderr'

tensorflow - 名称 tf.Session 已弃用。请改用 tf.compat.v1.Session

java - 如何调整目录中的图像大小?

python - 无法从 tensorflow /keras 中加载的模型获取梯度

c++ - 优化的 float 模糊变化

python - 如何找到最右边黑色像素的位置

python - 使用(常量)参数保存/加载 Keras 模型

python - keras模型H5理论上是如何工作的