python - 在 visualize_cv2.py 中列出超出范围的索引

标签 python python-3.x list data-science index-error

我想从 visualize_cv2.py 启动 mask-rcnn 模型。我的目标是只训练来自 class_names - person 的 1 个元素。为此,我创建了 class_names1(从这个 python 文件中添加了完整代码以便更好地理解):

import cv2
import numpy as np
import os
import sys
import coco
import utils
import model as modellib

ROOT_DIR = os.getcwd()
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)


class InferenceConfig(coco.CocoConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1


config = InferenceConfig()
config.display()

model = modellib.MaskRCNN(
    mode="inference", model_dir=MODEL_DIR, config=config
)
model.load_weights(COCO_MODEL_PATH, by_name=True)
class_names = [
        'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
        'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
        'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
        'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
        'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
        'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
        'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
        'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
        'teddy bear', 'hair drier', 'toothbrush'
    ]
class_names1 = class_names[1]

def random_colors(N):
    np.random.seed(1)
    colors = [tuple(255 * np.random.rand(3)) for _ in range(N)]
    return colors


colors = random_colors(len(class_names))
class_dict = {
    name: color for name, color in zip(class_names, colors)
}


def apply_mask(image, mask, color, alpha=0.5):
    """apply mask to image"""
    for n, c in enumerate(color):
        image[:, :, n] = np.where(
            mask == 1,
            image[:, :, n] * (1 - alpha) + alpha * c,
            image[:, :, n]
        )
    return image


def display_instances(image, boxes, masks, ids, names, scores):
    """
        take the image and results and apply the mask, box, and Label
    """
    n_instances = boxes.shape[0]

if not n_instances:
    print('NO INSTANCES TO DISPLAY')
else:
    assert boxes.shape[0] == masks.shape[-1] == ids.shape[0]

for i in range(n_instances):
    if not np.any(boxes[i]):
        continue

    y1, x1, y2, x2 = boxes[i]
    label = names[ids[i]]
    color = class_dict[label]
    score = scores[i] if scores is not None else None
    caption = '{} {:.2f}'.format(label, score) if score else label
    mask = masks[:, :, i]

    image = apply_mask(image, mask, color)
    image = cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
    image = cv2.putText(
        image, caption, (x1, y1), cv2.FONT_HERSHEY_COMPLEX, 0.7, color, 2
    )

return image


 if __name__ == __main__:


capture = cv2.VideoCapture(0)


# these 2 lines can be removed if you dont have a 1080p camera.
#capture.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
#capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

while True:
    ret, frame = capture.read()
    results = model.detect([frame], verbose=0)
    r = results[0]
    frame = display_instances(
        frame, r['rois'], r['masks'], r['class_ids'], class_names, r['scores']
    )
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

capture.release()
cv2.destroyAllWindows()

但是如果我运行它,我会得到一个错误:

Traceback (most recent call last):

File "visualize_cv2.py", line 86, in display_instances
label = names[ids[i]]
IndexError: string index out of range

正如我所想,我需要将此行 (86) 更改为某些内容。但是不明白如何(我是 python 的新手)。

最佳答案

问题是它需要一个位于索引 0 的类作为背景,而您的类新定义的类从索引 1 开始。因此将代码更改为

class_names = [
    'BG', 'person'
]

它会解决这个问题。 如果你需要更多的类,那么只需在人之后添加,'BG' 必须始终具有索引 0

class_names = [
    'BG', 'person', 'some other class', 'some other class', '...'
]

关于python - 在 visualize_cv2.py 中列出超出范围的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50229639/

相关文章:

python - 如何创建一个numpy列表数组?

python - Django 序列化程序 : Getting an ordered dictionary in embeded serializer

Python 转换查询输出

python - 绘制 XGBoost 模型特征重要性的增益、覆盖率、权重

python - 如何在 numpy 中将 int 值更改/重新映射到 str

也可以充当客户端的Python服务器

python - 访问嵌套序列化器字段中的序列化器实例

python - 使用 numpy nan 查找列表的最大值

c# - 通过 lambda 从另一个集合中排除一个集合

python - 在 Python 中,我如何按字典的某个值 + 字母顺序对字典列表进行排序?