目标
我正在尝试复制此 paper 中描述的应用程序(第 4.1 节),其中稀疏主成分分析应用于文本语料库,输出为 K 个主成分,每个主成分显示“否则隐藏的结构”。换句话说,每个主要组成部分都应包含一个单词列表,所有这些单词都有一个共同的主题。
我已经使用 sklearn 的 MiniBatchSparsePCA 包来尝试复制该应用程序,尽管我的输出是一个零矩阵。
数据
我的数据来自在 Stata 中清理的一项调查。它是一个包含 386 个答案的向量;这是句子。
我的尝试
# IMPORT LIBRARIES #
####################################
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from sklearn import decomposition
####################################
# USE SKLEARN TO IMPORT STATA DATA. #
# Data comes from a survey, which was cleaned using Stata.
####################################
data_source = "/Users/****/q19_free_text.dta"
raw_data = pd.read_stata(data_source) #Reading in the data from a Stata file.
text_data = raw_data.iloc[:,1] #Cleaning out Observation ID number.
text_data.shape # Out[268]: (368, ) - There are 368 text (sentence) answers.
####################################
# Term Frequency – Inverse Document- Word Frequency
####################################
vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,stop_words='english')
X_train = vectorizer.fit_transform(text_data)
spca = decomposition.MiniBatchSparsePCA(n_components=2, alpha=0.5)
spca.fit(X_train)
#TypeError: A sparse matrix was passed, but dense data is required. Use X.toarray() to convert to a dense numpy array.
X_train2 = X_train.toarray() #Trying with a dense array...
spca.fit(X_train2)
components = spca.components_
print(components) #Out: [[ 0. 0. 0. ..., 0. 0. 0.]
# [ 0. 0. 0. ..., 0. 0. 0.]]
components.shape #Out: (2, 916)
# Empty output!
其他说明
我使用这些来源编写了上面的代码:
最佳答案
(...) to do something similar to that which is done in section 4.1 in the paper linked. There they 'summarize' a text corpus by using SPCA and the output is K components, where each component is a list of words (or, features).
如果我理解正确的话,你会问如何检索组件的单词。
您可以通过检索组件中非零条目的索引来完成此操作(在组件
上使用适当的numpy
代码)。然后使用vectorizer.vocabulary_
您可以找出在您的组件中找到了哪些索引(单词/标记)。
参见this notebook示例实现(我使用了 20 个新闻组数据集)。
关于machine-learning - 文本数据上的 MiniBatchSparsePCA,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48034724/