python - 如何在 scikit-learn 中使用 LDA(线性判别式)进行预测?

标签 python machine-learning scikit-learn pca linear-discriminant

我一直在测试 PCA 和 LDA 对我想要自动识别的 3 种不同类型的图像标签进行分类的效果。在我的代码中,X 是我的数据矩阵,其中每一行都是图像的像素,y 是一个一维数组,说明每一行的分类。

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.lda import LDA

pca = PCA(n_components=2)
X_r = pca.fit(X).transform(X)

plt.figure(figsize = (35, 20))
plt.scatter(X_r[:, 0], X_r[:, 1], c=y, s=200)

lda = LDA(n_components=2)
X_lda = lda.fit(X, y).transform(X)
plt.figure(figsize = (35, 20))
plt.scatter(X_lda[:, 0], X_lda[:, 1], c=y, s=200)

使用 LDA,我最终得到了 3 个清晰可辨的集群,它们之间只有轻微的重叠。现在,如果我有一张我想要分类的新图像,一旦我将它变成一维数组,我该如何预测它应该落入哪个簇,如果它离中心太远,我怎么能说分类是“不确定的” “?我也很好奇“.transform(X)”函数在我拟合后对我的数据做了什么。

最佳答案

在使用一些数据 X 训练 LDA 模型后,您可能想要投影一些其他数据 Z。在这种情况下你应该做的是:

lda = LDA(n_components=2) #creating a LDA object
lda = lda.fit(X, y) #learning the projection matrix
X_lda = lda.transform(X) #using the model to project X 
# .... getting Z as test data....
Z = lda.transform(Z) #using the model to project Z
z_labels = lda.predict(Z) #gives you the predicted label for each sample
z_prob = lda.predict_proba(Z) #the probability of each sample to belong to each class

请注意,“拟合”用于拟合模型,而不是拟合数据

因此,transform 用于构建表示(在本例中为投影),predict 用于预测每个样本的标签。 (这用于从 sklearn 中的 BaseEstimator 继承的所有类。

您可以阅读 documentation更多选项和属性。

此外,sklearn 的 API 允许您执行 pca.fit_transform(X) 而不是 pca.fit(X).transform(X)。如果您在代码中的这一点之后对模型本身不感兴趣,请使用此版本。

一些评论: 由于 PCA 是 Unsupervised approach , LDA 是一种更好的方法来执行您当前正在执行的这种“视觉”分类。

此外,如果您对分类感兴趣,您可以考虑使用不同类型的分类器,不一定是 LDA,尽管它是一种很好的可视化方法。

关于python - 如何在 scikit-learn 中使用 LDA(线性判别式)进行预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31107945/

相关文章:

python - 使用 TensorFlow 优化器优化涉及 tf.keras 的 "model.predict()"的函数?

python - 在 Python 中使用 sklearn.preprocessing 进行数据转换

python - Scikit学习: Applying Mean Shift on a multi-dimensional dataset

python - 将 pandas DataFrame 列拆分为 OneHot/Binary 列

python - 表示要存储在文本文件中的图形的最佳方式

python - random.choice 在 Python 2 和 3 上给出不同的结果

python - 使用感知器反向传播的问题

algorithm - 内容相关搜索:对购物产品进行分类

machine-learning - 每个节点需要多少个突触?

python - matplotlib : how to add legend? 中的 3D PCA