python - 具有归一化混淆矩阵的 Matplotlib 图的颜色条不会更新值

标签 python matplotlib machine-learning colorbar

我正在尝试使用 KNearest Neighbors 解决多类机器学习问题,并使用 Matplotlib.pyplot 的 imshow 为数据中所有 10 个类的预测绘制混淆矩阵。有些类在数据中的出现次数比其他类多得多,最多 3000 个,而其他类可能只有 50 个,因此我将其标准化以仅显示百分比。图表旁边有一个颜色条,如果没有标准化,其范围将是 1 到 3000,这是有道理的。然而,在标准化之后,范围一直保持在 3000。我使用的是 Scikit Learn 自己在其网站 here 提供的绘图函数。 。我是否遗漏了一些明显的东西,或者是否有额外的步骤来减少颜色条值范围?

代码

virdis = plt.cm.viridis
blues = plt.cm.Blues
autumn = plt.cm.autumn

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)

    bounds=[0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1]
    plt.colorbar(boundaries=bounds)

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    cm = np.around(cm, decimals=3)

    thresh = cm.max() / 2.

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if i == 9 and j == 9 else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')

knn = KNeighborsClassifier()
knn.fit(X_train, y_train)

knn_score = knn.score(X_test, y_test)
knn_fold_score = model_selection.cross_val_score(knn, X_test, y_test, cv=10).mean()
predictions = knn.predict(X_test)

c_matrix = confusion_matrix(y_test, predictions)

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(c_matrix, classes=country_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

enter image description here

最佳答案

正如您所理解的,颜色条及其范围保留并且应该保留在绘图中,即 plt.imshow。 Scikit Learn 示例和您的示例都在进行标准化或决定是否进行标准化之前绘制矩阵。因此,这两个图及其关联的颜色条看起来完全相同。如果您在绘图之前处理归一化,即移动以下 block :

if normalize:
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    print("Normalized confusion matrix")
else:
    print('Confusion matrix, without normalization')

cm = np.around(cm, decimals=3)

plt.imshow(cm, interpolation='nearest', cmap=cmap) 前面,标准化绘图的颜色条范围将从 0 到 1。再次提醒您,绘图本身(的颜色)也会改变。我认为仅将颜色条的文本标签更改为 0 到 1 的范围而不更改颜色条本身及其关联图并不是一个好主意。

关于python - 具有归一化混淆矩阵的 Matplotlib 图的颜色条不会更新值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45251905/

相关文章:

python - 如何在 matplotlib 的主图外绘制箭头和矩形(用于蛋白质 sec 结构)?

python - 在 Python 中查找单词的所有变体(或时态)

machine-learning - 微调VGG最后一层非常慢

python - 检测一个键是否被按下 - python

Python:为双直方图对齐 bin 边缘之间的条

python - 如何将绘图标题中带有下划线的字符串设置为斜体

machine-learning - Dropout 应该插入到哪里?全连接层。?卷积层。?或两者。?

python - 将列标题添加到 pandas 数据框..但是 NAN 是所有数据,即使标题是相同的维度

python - 优化词梯

python - 在大矩阵中搜索值