我正在尝试使用 KNearest Neighbors 解决多类机器学习问题,并使用 Matplotlib.pyplot 的 imshow 为数据中所有 10 个类的预测绘制混淆矩阵。有些类在数据中的出现次数比其他类多得多,最多 3000 个,而其他类可能只有 50 个,因此我将其标准化以仅显示百分比。图表旁边有一个颜色条,如果没有标准化,其范围将是 1 到 3000,这是有道理的。然而,在标准化之后,范围一直保持在 3000。我使用的是 Scikit Learn 自己在其网站 here 提供的绘图函数。 。我是否遗漏了一些明显的东西,或者是否有额外的步骤来减少颜色条值范围?
代码
virdis = plt.cm.viridis
blues = plt.cm.Blues
autumn = plt.cm.autumn
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
bounds=[0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1]
plt.colorbar(boundaries=bounds)
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
cm = np.around(cm, decimals=3)
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, cm[i, j],
horizontalalignment="center",
color="white" if i == 9 and j == 9 else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
knn = KNeighborsClassifier()
knn.fit(X_train, y_train)
knn_score = knn.score(X_test, y_test)
knn_fold_score = model_selection.cross_val_score(knn, X_test, y_test, cv=10).mean()
predictions = knn.predict(X_test)
c_matrix = confusion_matrix(y_test, predictions)
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(c_matrix, classes=country_names, normalize=True,
title='Normalized confusion matrix')
plt.show()
最佳答案
正如您所理解的,颜色条及其范围保留并且应该保留在绘图中,即 plt.imshow
。 Scikit Learn 示例和您的示例都在进行标准化或决定是否进行标准化之前绘制矩阵。因此,这两个图及其关联的颜色条看起来完全相同。如果您在绘图之前处理归一化,即移动以下 block :
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
cm = np.around(cm, decimals=3)
在 plt.imshow(cm, interpolation='nearest', cmap=cmap)
前面,标准化绘图的颜色条范围将从 0 到 1。再次提醒您,绘图本身(的颜色)也会改变。我认为仅将颜色条的文本标签更改为 0 到 1 的范围而不更改颜色条本身及其关联图并不是一个好主意。
关于python - 具有归一化混淆矩阵的 Matplotlib 图的颜色条不会更新值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45251905/