python - 使用 sklearn 进行网格搜索的显式(预定义)验证集

标签 python validation scikit-learn cross-validation

我有一个数据集,之前已分为 3 组:训练、验证和测试。为了比较不同算法的性能,必须按照给定的方式使用这些集合。

我现在想使用验证集优化我的 SVM 的参数。但是,我找不到如何将验证集显式输入 sklearn.grid_search.GridSearchCV()。下面是我之前用于在训练集上进行 K 折交叉验证的一些代码。但是,对于这个问题,我需要使用给定的验证集。我该怎么做?

from sklearn import svm, cross_validation
from sklearn.grid_search import GridSearchCV

# (some code left out to simplify things)

skf = cross_validation.StratifiedKFold(y_train, n_folds=5, shuffle = True)
clf = GridSearchCV(svm.SVC(tol=0.005, cache_size=6000,
                             class_weight=penalty_weights),
                     param_grid=tuned_parameters,
                     n_jobs=2,
                     pre_dispatch="n_jobs",
                     cv=skf,
                     scoring=scorer)
clf.fit(X_train, y_train)

最佳答案

使用 PredefinedSplit

ps = PredefinedSplit(test_fold=your_test_fold)

然后在GridSearchCV

中设置cv=ps

test_fold : “array-like, shape (n_samples,)

test_fold[i] gives the test set fold of sample i. A value of -1 indicates that the corresponding sample is not part of any test set folds, but will instead always be put into the training fold.

另见 here

when using a validation set, set the test_fold to 0 for all samples that are part of the validation set, and to -1 for all other samples.

关于python - 使用 sklearn 进行网格搜索的显式(预定义)验证集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31948879/

相关文章:

python - 为什么我的实体没有显示在开发服务器上,即使它们显示在仪表板上?

c# - ASP.NET MVC 数据注释验证 ErrorMessageResourceType

python - 重新调整 3d numpy 数组中的值

python - GridSearchCV - 类型错误 : an integer is required

python - 无法导入名称 ‘etree’

python - 遇到两个单独的关键字后跳出for循环

python - worker_concurrency 配置对 celery 无效

regex - perl 中的 4 位数字验证

java - 你能限制注释目标是某个类的子类吗?

python - 精度比 gridsearchCV 低