我想做什么?
我正在尝试在 GridSearchCV()
中使用 StratifiedKFold()
。
那么,是什么让我感到困惑?
当我们使用 K 折交叉验证时,我们只是在 GridSearchCV()
中传递 CV 的数量,如下所示。
grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=5, scoring='f1', return_train_score=True, n_jobs=2)
然后,当我需要使用 StratifiedKFold()
时,我认为过程应该保持不变。即,仅将拆分数 - StratifiedKFold(n_splits=5)
设置为 cv
。
grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=StratifiedKFold(n_splits=5), scoring='f1', return_train_score=True, n_jobs=2)
但是this answer说
whatever the cross validation strategy used, all that is needed is to provide the generator using the function split, as suggested:
kfolds = StratifiedKFold(5) clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain)) clf.fit(xtrain, ytrain)
此外,this question的答案之一也建议这样做。这意味着,他们建议在使用 GridSearchCV()
期间调用拆分函数:StratifiedKFold(n_splits=5).split(xtrain,ytrain)
。但是,我发现调用 split()
和不调用 split()
给我的 f1 分数相同。
因此,我的问题
我不明白为什么我们需要在分层 K 折叠期间调用
split()
函数作为 我们不需要在 K Fold CV 期间做这类事情。如果调用
split()
函数,GridSearchCV()
将如何作为Split()
函数工作 returns training and testing data set indices ?也就是说,我想知道GridSearchCV()
将如何使用这些索引?
最佳答案
GridSearchCV 基本上很聪明,可以为该 cv 参数采用多个选项 - 数字、拆分索引的迭代器或具有拆分函数的对象。可以看看代码here , 复制如下。
cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
if (classifier and (y is not None) and
(type_of_target(y) in ('binary', 'multiclass'))):
return StratifiedKFold(cv)
else:
return KFold(cv)
if not hasattr(cv, 'split') or isinstance(cv, str):
if not isinstance(cv, Iterable) or isinstance(cv, str):
raise ValueError("Expected cv as an integer, cross-validation "
"object (from sklearn.model_selection) "
"or an iterable. Got %s." % cv)
return _CVIterableWrapper(cv)
return cv # New style cv objects are passed without any modification
基本上,如果您不传递任何内容,它会使用带有 5 的 KFold。它也足够聪明,可以自动使用 StratifedKFold,如果它是一个分类问题并且目标是二元/多类。
如果您传递一个带有拆分函数的对象,它只会使用它。如果您不传递它们中的任何一个,而是传递一个可迭代对象,它会假定这是一个可迭代的拆分索引并为您包装它。
因此,在您的情况下,假设这是一个二元/多类目标的分类问题,以下所有内容都会给出完全相同的结果/拆分 - 您使用哪一个都没有关系!
cv=5
cv=StratifiedKFold(5)
cv=StratifiedKFold(5).split(xtrain,ytrain)
关于split - 为什么我们应该在将 StratifiedKFold() 作为 GridSearchCV 的参数传递时调用 split() 函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62174112/