python - sklearn cross_val_score 的准确性低于手动交叉验证

标签 python python-3.x scikit-learn cross-validation

我正在研究一个文本分类问题,我是这样设置的(为了简洁起见,我省略了数据处理步骤,但它们会生成一个名为 data 的数据框包含 Xy 列):

import sklearn.model_selection as ms
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier

sim = Pipeline([('vec', TfidfVectorizer((analyzer="word", ngram_range=(1, 2))),
                ("rdf", RandomForestClassifier())])

现在我尝试通过在 2/3 的数据上训练它并在剩余的 1/3 上评分来验证这个模型,如下所示:

train, test = ms.train_test_split(data, test_size = 0.33)
sim.fit(train.X, train.y)
sim.score(test.X, test.y)
# 0.533333333333

我想对三个不同的测试集执行三次此操作,但使用 cross_val_score 给我的结果要低得多。

ms.cross_val_score(sim, data.X, data.y)
# [ 0.29264069  0.36729223  0.22977941]

据我所知,该数组中的每个分数都应该通过对 2/3 的数据进行训练并使用 sim.score 方法对剩余的 1/3 进行评分来生成。那么为什么它们都低得多呢?

最佳答案

我在写问题的过程中解决了这个问题,所以这里是:

cross_val_score 的默认行为是使用 KFoldStratifiedKFold 来定义折叠。默认情况下,两者都有参数 shuffle=False,因此不会从数据中随机提取折叠:

import numpy as np
import sklearn.model_selection as ms

for i, j in ms.KFold().split(np.arange(9)):
    print("TRAIN:", i, "TEST:", j)
TRAIN: [3 4 5 6 7 8] TEST: [0 1 2]
TRAIN: [0 1 2 6 7 8] TEST: [3 4 5]
TRAIN: [0 1 2 3 4 5] TEST: [6 7 8]

我的原始数据是按标签排列的,因此通过这种默认行为,我试图预测很多我在训练数据中没有看到的标签。如果我强制使用 KFold(我正在做分类,所以 StratifiedKFold 是默认值),这会更加明显:

ms.cross_val_score(sim, data.text, data.label, cv = ms.KFold())
# array([ 0.05530776,  0.05709188,  0.025     ])
ms.cross_val_score(sim, data.text, data.label, cv = ms.StratifiedKFold(shuffle = False))
# array([ 0.2978355 ,  0.35924933,  0.27205882])
ms.cross_val_score(sim, data.text, data.label, cv = ms.KFold(shuffle = True))
# array([ 0.51561106,  0.50579839,  0.51785714])
ms.cross_val_score(sim, data.text, data.label, cv = ms.StratifiedKFold(shuffle = True))
# array([ 0.52869565,  0.54423592,  0.55626715])

手工做事给了我更高的分数,因为 train_test_split 做的事情与 KFold(shuffle = True) 做的一样。

关于python - sklearn cross_val_score 的准确性低于手动交叉验证,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43688058/

相关文章:

python - 深度学习模型预测关键词点击次数

python - 从 numpy 数组获取指针以将图像发送到 C++

python - Pandas - GroupBy 2 列 - 无法重置索引

python - 列表切片并找到第二大值python

python - 检查某些文件夹中是否有任何图像重复的最高效(比我的更好)的方法?

machine-learning - 如何使用整个训练示例来估计 sklearn RandomForest 中的类概率

python - 将两个 boolean 列转换为 Pandas 中的类 ID

python - Apache spark 和 python lambda

python - Sklearn LogisticRegressionCV 的类似数组的输入

python - 基于Python中的键聚合字典列表上的值