python - sklearn kmeans 上的预测方法,它是如何工作的以及它在做什么?

标签 python scikit-learn k-means

我一直在玩 sklearn 的 k-means 聚类类,但我对其预测方法感到困惑。

我在 iris 数据集上应用了一个模型,如下所示:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

pca = PCA(n_components = 2).fit(X_train)

X_train_pca = pca.transform(X_train)
X_test_pca = pca.transform(X_test)

kmeans_pca = KMeans(n_clusters=3).fit(X_train_pca)

并做出预测:

pred = kmeans_pca.predict(X_test_pca)

print(classification_report(y_test, pred))

          precision    recall  f1-score   support

       0       1.00      1.00      1.00        19
       1       0.76      0.87      0.81        15
       2       0.86      0.75      0.80        16

    accuracy                           0.88        50
   macro avg       0.87      0.87      0.87        50
weighted avg       0.88      0.88      0.88        50

预测似乎很准确,这让我很困惑,因为我没有将标签传递给训练集。我读过这篇文章What is the use of predict() method in kmeans implementation of scikit learn?这告诉我预测方法正在调用距离测试数据最近的聚类质心。但是,我不知道 sklearn 如何在训练阶段正确分配 ID(即 kmeans_pca.labels_ 到相应的 y_train),因为训练阶段不涉及标签。

我意识到 k-means 不用于分类任务,但我想知道这些结果是如何实现的。这样,在 sklearn 中执行 k-means 聚类时 .predict() 可以起到什么作用呢?

最佳答案

KMeans 聚类是无监督学习的一个示例。这意味着,它确实没有考虑任何训练标签。

相反,示例完全根据特征之间的模式进行聚类 - 相似的示例被分组在一起。对于 Iris 数据集,同一朵花的不同示例往往具有相似的萼片和花瓣长度和宽度(即花的“特征”)。这意味着仅这些功能就给出了如何对花朵进行分组 - 无需提供明确的标签。

要了解结果是如何实现的,了解算法可能会有所帮助。以下是最常见的 KMeans 算法,基于以下步骤:

  1. 初始化 K 个不同的簇质心(可能随机,但不一定)
  2. 将每个示例分配给最近的聚类(例如,基于特征向量和聚类质心之间的欧几里德距离)
  3. 根据第 2 步中找到的集群成员重新计算集群质心。

重复步骤 2 和 3,直到收敛(即当集群分配不再改变时)。

上述算法最终将相似的示例分配给相同的集群,因此只关心特征之间的相似性,而不关心它们的标签。

.predict() 方法将为您提供任何测试示例中最有可能的聚类分配(例如“花”,如上所述)。事实上,这是通过分配到最近的簇质心来完成的,如上所述。

关于python - sklearn kmeans 上的预测方法,它是如何工作的以及它在做什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72862052/

相关文章:

Python ValueError : non-broadcastable output operand with shape (124, 1) 与广播形状不匹配 (124,13)

algorithm - 在 k 均值聚类中最小化损失函数意味着什么?

python - 快速计算整个数据集到每个聚类中心的距离

python - 使用 sklearn 使用 Keras 数据生成器绘制混淆矩阵

Python - 使用 K-means 聚类。一些方差为零的列

python - 用户输入困难并替换字母

python - 绘制截断正态分布

python - 为什么 PyCrypto 不使用默认 IV?

python - 如何在 plotly 散点图中添加图像而不是点(python)

python - 在 svm 中预测多类