split - 为什么我们应该在将 StratifiedKFold() 作为 GridSearchCV 的参数传递时调用 split() 函数?

标签 split cross-validation grid-search gridsearchcv k-fold

我想做什么?

我正在尝试在 GridSearchCV() 中使用 StratifiedKFold()

那么,是什么让我感到困惑?

当我们使用 K 折交叉验证时,我们只是在 GridSearchCV() 中传递 CV 的数量,如下所示。

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=5, scoring='f1', return_train_score=True, n_jobs=2)

然后,当我需要使用 StratifiedKFold() 时,我认为过程应该保持不变。即,仅将拆分数 - StratifiedKFold(n_splits=5) 设置为 cv

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=StratifiedKFold(n_splits=5), scoring='f1', return_train_score=True, n_jobs=2)

但是this answer

whatever the cross validation strategy used, all that is needed is to provide the generator using the function split, as suggested:

kfolds = StratifiedKFold(5)
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain))
clf.fit(xtrain, ytrain)

此外,this question的答案之一也建议这样做。这意味着,他们建议在使用 GridSearchCV() 期间调用拆分函数:StratifiedKFold(n_splits=5).split(xtrain,ytrain)。但是,我发现调用 split() 和不调用 split() 给我的 f1 分数相同。

因此,我的问题

  • 我不明白为什么我们需要在分层 K 折叠期间调用 split() 函数作为 我们不需要在 K Fold CV 期间做这类事情。

  • 如果调用 split() 函数,GridSearchCV() 将如何作为 Split() 函数工作 returns training and testing data set indices ?也就是说,我想知道 GridSearchCV() 将如何使用这些索引?

最佳答案

GridSearchCV 基本上很聪明,可以为该 cv 参数采用多个选项 - 数字、拆分索引的迭代器或具有拆分函数的对象。可以看看代码here , 复制如下。

cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
    if (classifier and (y is not None) and
            (type_of_target(y) in ('binary', 'multiclass'))):
        return StratifiedKFold(cv)
    else:
        return KFold(cv)

if not hasattr(cv, 'split') or isinstance(cv, str):
    if not isinstance(cv, Iterable) or isinstance(cv, str):
        raise ValueError("Expected cv as an integer, cross-validation "
                         "object (from sklearn.model_selection) "
                         "or an iterable. Got %s." % cv)
    return _CVIterableWrapper(cv)

return cv  # New style cv objects are passed without any modification

基本上,如果您不传递任何内容,它会使用带有 5 的 KFold。它也足够聪明,可以自动使用 StratifedKFold,如果它是一个分类问题并且目标是二元/多类。

如果您传递一个带有拆分函数的对象,它只会使用它。如果您不传递它们中的任何一个,而是传递一个可迭代对象,它会假定这是一个可迭代的拆分索引并为您包装它。

因此,在您的情况下,假设这是一个二元/多类目标的分类问题,以下所有内容都会给出完全相同的结果/拆分 - 您使用哪一个都没有关系!

cv=5
cv=StratifiedKFold(5)
cv=StratifiedKFold(5).split(xtrain,ytrain)

关于split - 为什么我们应该在将 StratifiedKFold() 作为 GridSearchCV 的参数传递时调用 split() 函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62174112/

相关文章:

c# - 如何在一行中将字符串拆分和修剪成多个部分?

java - 使用 Split 方法创建分词器

python - 具有分层交叉验证的多个性能指标

python - 网格搜索中的 H2OResponseError 获取网格排序

python - 拟合 sklearn GridSearchCV 模型

python - 在 Scikit-learn 中使用 Smote 和 Gridsearchcv

python - 在python中分割并用另一个相同长度的字符串替换一个字符串

java - StringTokenizer delimit 一次

validation - 交叉验证是如何实现的?

python - predict_proba 用于交叉验证模型