我从训练数据集中的图像中提取了一些特征,然后应用了这些特征并使用 train_test_split
将数据分割为训练和测试:
Train data : (60, 772)
Test data : (20, 772)
Train labels: (60,)
Test labels : (20,)
接下来我想做的是将 SVM 分类器应用于测试数据集中的图像并查看结果。
# create the model - SVM
#clf = svm.SVC(kernel='linear', C=40)
clf = svm.SVC(kernel='rbf', C=10000.0, gamma=0.0001)
# fit the training data to the model
clf.fit(trainDataGlobal, trainLabelsGlobal)
# path to test data
test_path = "dataset/test"
# loop through the test images
for index,file in enumerate(glob.glob(test_path + "/*.jpg")):
# read the image
image = cv2.imread(file)
# resize the image
image = cv2.resize(image, fixed_size)
# predict label of test image
prediction = clf.predict(testDataGlobal)
prediction = prediction[index]
#print("Accuracy: {}%".format(clf.score(testDataGlobal, testLabelsGlobal) * 100 ))
# show predicted label on image
cv2.putText(image, train_labels[prediction], (20,30), cv2.FONT_HERSHEY_TRIPLEX, .7 , (0,255,255), 2)
# display the output image
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.show()
尽管我可以看到它说准确度为 60%,但我并没有获得很好的准确度。然而大多数图像的标签是错误的。我在预测
中传递了错误的参数吗?
我可以做些什么来改善这个问题?
编辑:我已尝试使用以下代码执行您所说的操作,但收到一条错误消息,提示我应该 reshape 我的 feature_vector
。所以我这样做了,然后出现以下错误。
(作为引用:feature_extraction_method(image).shape
为 (772,)
。)
for filename in test_images:
# read the image and resize it to a fixed-size
img = cv2.imread(filename)
img = cv2.resize(img, fixed_size)
feature_vector = feature_extraction_method(img)
prediction = clf.predict(feature_vector.reshape(-1, 1))
cv2.putText(img, prediction, (20, 30), cv2.FONT_HERSHEY_TRIPLEX, .7 , (0, 255, 255), 2)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-71-2b8ff4146d8e> in <module>()
19
20 feature_vector = feature_extraction_method(img)
---> 21 prediction = clf.predict(feature_vector.reshape(-1, 1))
22 cv2.putText(img, prediction, (20, 30), cv2.FONT_HERSHEY_TRIPLEX, .7 , (0, 255, 255), 2)
23 plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py in predict(self, X)
546 Class labels for samples in X.
547 """
--> 548 y = super(BaseSVC, self).predict(X)
549 return self.classes_.take(np.asarray(y, dtype=np.intp))
550
/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py in predict(self, X)
306 y_pred : array, shape (n_samples,)
307 """
--> 308 X = self._validate_for_predict(X)
309 predict = self._sparse_predict if self._sparse else self._dense_predict
310 return predict(X)
/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py in _validate_for_predict(self, X)
457 raise ValueError("X.shape[1] = %d should be equal to %d, "
458 "the number of features at training time" %
--> 459 (n_features, self.shape_fit_[1]))
460 return X
461
ValueError: X.shape[1] = 1 should be equal to 772, the number of features at training time
最佳答案
您的代码有两个主要问题。
首先,您不需要在 for 循环的每次迭代中对整个测试集进行分类。一次预测一张图像的类标签就足够了:
prediction = svm.clf.predict([testDataGlobal[index, :]])
请注意,testDataGlobal[index, :]
必须括在方括号 [ ]
中,因为 predict()
方法需要一个 2D 数组-类似变量。
第二,也是最重要的,让我们假设函数 glob
生成三个图像文件的列表,即 imgA.jpg
、imgB.jpg
和 imgC.jpg
,让我们将它们对应的特征向量表示为 featsA
、featsB
和 featsC
。为了使您的代码正常工作,testDataGlobal
必须按如下方式排列:
[featsA,
featsB,
featsC]
如果特征向量以不同的顺序排列,您可能会得到错误的结果。
您可以通过以下代码片段正确标记图像(未经测试):
test_images = glob.glob("dataset/test/*.jpg")
for filename in test_images:
img = cv2.imread(filename)
img = cv2.resize(img, fixed_size)
feature_vector = your_feature_extraction_method(img)
prediction = svm.clf.predict([feature_vector])
cv2.putText(img, prediction[0], (20, 30),
cv2.FONT_HERSHEY_TRIPLEX, .7 , (0, 255, 255), 2)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
其中 your_feature_extraction_method()
代表使用图像并返回其特征向量(类似一维数组)的函数。
注意:不要忘记将 feature_vector
括在方括号 [ ]
中。您还可以使用以下任一方法将 feature_vector
的维度增加一维:
prediction = svm.clf.predict(feature_vector[None, :])
prediction = svm.clf.predict(feature_vector[np.newaxis, :])
prediction = svm.clf.predict(np.atleast_2d(feature_vector))
关于python - SVM 图像预测 Python,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54911334/