python - Sklearn LDA 分析不会生成二维

标签 python scikit-learn

我正在尝试在 matplotlib 图上绘制具有二元分类的 3 特征数据集。这适用于指南 (http://www.apnorton.com/blog/2016/12/19/Visualizing-Multidimensional-Data-in-Python/) 中提供的示例数据集,但是当我尝试插入自己的数据集时,LinearDiscriminantAnalysis 将只输出一维序列,无论我在“n_components”中输入的数字是多少。为什么这不适用于我自己的代码?

Data = pd.read_csv("DataFrame.csv", sep=";")
x = Data.iloc[:, [3, 5, 7]]
y = Data.iloc[:, 8]

lda = LDA(n_components=2)
lda_transformed = pd.DataFrame(lda.fit_transform(x, y))

plt.scatter(lda_transformed[y==0][0], lda_transformed[y==0][1], label='Loss', c='red')
plt.scatter(lda_transformed[y==1][0], lda_transformed[y==1][1], label='Win', c='blue')

plt.legend()
plt.show()

最佳答案

如果不同类别标签的数量 C 小于观察数量(几乎总是如此),则线性判别分析将始终产生 C - 1 区分组件。使用 sklearn API 中的 n_components 只是一种选择可能更少 组件的方法,例如在您知道要减少到什么维度的情况下。但是您永远无法使用 n_components 来获取更多 组件。

这在 Wikipedia section on Multiclass LDA 中进行了讨论.类间散布的定义为

\Sigma_{b} = (1 / C) \sum_{i}^{C}( (\mu_{i} - mu)(\mu_{i} - mu)^{T}

这是类别均值总体之间的经验协方差矩阵。 By definition ,这样的协方差矩阵的秩至多为 C - 1

... the variability between features will be contained in the subspace spanned by the eigenvectors corresponding to the C − 1 largest eigenvalues ...

因此,由于 LDA 使用类别均值 协方差矩阵的分解,这意味着它可以提供的降维是基于类别标签的数量,而不是基于样本大小或特征维度.

在您链接的示例中,有多少功能 并不重要。关键是这个例子使用了 3 个模拟的聚类中心,所以有 3 个类标签。这意味着线性判别分析可以将数据投影到一维或二维判别子空间。

但在您的数据中,您一开始只有 2 个类标签,这是一个二元问题。这意味着线性判别模型的维度最多可以是一维的,从字面上看是一条线,它形成了两个类之间的决策边界。在这种情况下,使用 LDA 进行降维只是将数据点投影到该分隔线的特定法向量上。

如果你想专门降维到二维,你可以尝试sklearn提供的许多其他算法:t-SNE、ISOMAP、PCA和内核PCA、随机投影、多维缩放等等。其中许多允许您选择投影空间的维度,直到原始特征维度,或者有时您甚至可以投影到更大的空间,例如内核 PCA。

关于python - Sklearn LDA 分析不会生成二维,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49955592/

相关文章:

python - 在Python中将月份和年份的列一起变成季度和年份的列

python - 使用conda安装opencv

python - Y错误的scikit Mixtypes

python - 将 Keras 集成到 SKLearn 管道?

python - 属性错误: 'WebElement' object has no attribute 'get_text' error extracting the text between the starting and ending tag using Selenium Python

python - Python中动态调用函数的方法是什么?

python - 使用 Pipeline 和 GridSearchCV 完成的训练数量

scikit-learn - TypeError : unhashable type

python - 当我在柴油中使用产量时,telnet 连接关闭

python - 图像中的对象检测 (HOG)