我正在进行一个机器人项目,涉及检测人体,为此我正在使用 tensorflow 和预定义的数据集来创建训练模型。由于我是机器学习的新手,因此我无法从分类器中正确获取输出。我只需要检测“人”,并希望避免检测到球,笔记本电脑或其他物体。
现在,我的网络摄像头可以检测到所有物体,例如球,球棒,笔记本电脑,电视等。我需要的输出仅是得分达到80%或以上的人。
我用于创建模型的代码是
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90
if not os.path.exists(MODEL_NAME + '/frozen_inference_graph.pb'):
print ('Downloading the model')
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
file_name = os.path.basename(file.name)
if 'frozen_inference_graph.pb' in file_name:
tar_file.extract(file, os.getcwd())
print ('Download complete')
else:
print ('Model already exists')
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
import cv2
cap = cv2.VideoCapture(1)
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
ret = True
while (ret):
ret,image_np = cap.read()
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
vis_util.visualize_boxes_and_labels_on_image_array(image_np,np.squeeze(boxes),np.squeeze(classes).astype(np.int32),np.squeeze(scores),category_index,use_normalized_coordinates=True,line_thickness=8)
cv2.imshow('image',cv2.resize(image_np,(1280,960)))
if cv2.waitKey(27) & 0xFF == ord('q'):
cv2.destroyAllWindows()
cap.release()
break
谁能解释我如何只检测准确度得分超过80%的人。
最佳答案
从文档here可以看到,您只需要检查person类。现在,vis_util
检查所有类。您只需要为person类添加if
条件。下面给出的是适当的标识符(来自docs)。
item {
name: "/m/01g317"
id: 1
display_name: "person"
}
关于python - 使用opencv,tensorflow和python进行人体检测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50642057/