我正在使用由 160 万条推文组成的情感140 数据集来训练和分析 python scikit-learn 库中不同分类器的准确性。我使用以下代码片段将推文矢量化为特征向量,然后将其输入分类器。
vectorizer = CountVectorizer(max_features = 2000)
train_set = (vectorizer.fit_transform(trainX)).toarray()
此后,我训练我的分类器
对象,其中包括GaussianNB()
、MultinomialNB()
、BernoulliNB()
>、LogisticRegression()
、LinearSVC()
和 RandomForestClassifier()
使用以下代码片段:
classifier.fit(train_vectors,trainy)
但是,在使用 trainset
的 toarray()
函数将矢量化器的转换集转换为 numpy 数组时,我发现该程序消耗了大量资源内存(大约 4-5 GB)仅用于 100k 个示例,每个示例的特征向量大小为 2000,即 100,000x2000 特征向量。
这是我的系统所能做到的最大值,因为我只有 8GB RAM。有人可以建议我如何继续,以便能够通过可能修改代码来使用可用内存来训练整个 1.6M 的训练数据集。如果我尝试使用上面的代码,将需要大约 72 GB 的 RAM,这是不可行的。
我还了解到,有一些规定可以用训练集的一小部分迭代地逐步训练某些分类器。 MultinomialNB()
和 BernoulliNB()
等分类器对此有规定(使用 partial_fit
),但我也使用的其他分类器不这样做't,所以这不是一个选择。
最佳答案
问题是,您首先想要实现什么目标?我问的原因是,由于问题的本质,矢量化文本具有大量维度。此外,max_features=2000
也不会让您在文本分类方面获得足够的性能。
长话短说:您提到的大多数分类器都使用稀疏向量,除了 GaussianNB
,这可以通过以下方式轻松验证:
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import GradientBoostingClassifier
from scipy.sparse import csr_matrix
from sklearn.datasets import load_digits
digits = load_digits()
X, y = digits.data, digits.target
for CLF in [GaussianNB, MultinomialNB, BernoulliNB, LogisticRegression, LinearSVC, GradientBoostingClassifier]:
print(CLF.__name__, end='')
try:
CLF().fit(csr_matrix(X), y == 0)
print(' PASS')
except TypeError:
print(' FAIL')
哪些输出:
GaussianNB FAIL
MultinomialNB PASS
BernoulliNB PASS
LogisticRegression PASS
LinearSVC PASS
GradientBoostingClassifier PASS
我建议,您只需从列表中删除 GaussianNB
并使用支持稀疏向量的分类器即可。您至少应该能够在 8g 限制内使用更多样本。
另请参阅this issue对于 scikit-learn,引用 Jake Vanderplas:
One reason sparse inputs are not implemented in
GaussianNB
is that very sparse data almost certainly does not meet the assumptions of the algorithm – when the bulk of the values are zero, a simple Gaussian is not a good fit to the data, and will almost never lead to a useful classification.
关于python - 如何使用 scikit-learn 训练/升级非常大的数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36180659/