python - 学习 : Cross validation for grouped data

标签 python scikit-learn cross-validation

我正在尝试对分组数据实现交叉验证方案。我希望使用 GroupKFold 方法,但我一直收到错误消息。我究竟做错了什么? 代码(与我使用的代码略有不同——我有不同的数据,所以我有一个更大的 n_splits,但其他一切都是一样的)

from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.grid_search import GridSearchCV
from xgboost import XGBRegressor
#generate data
x=np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13])
y= np.array([1,2,3,4,5,6,7,1,2,3,4,5,6,7])
group=np.array([1,0,1,1,2,2,2,1,1,1,2,0,0,2)]
#grid search
gkf = GroupKFold( n_splits=3).split(x,y,group)
subsample = np.arange(0.3,0.5,0.1)
param_grid = dict( subsample=subsample)
rgr_xgb = XGBRegressor(n_estimators=50)
grid_search = GridSearchCV(rgr_xgb, param_grid, cv=gkf, n_jobs=-1)
result = grid_search.fit(x, y)

错误:

Traceback (most recent call last):

File "<ipython-input-143-11d785056a08>", line 8, in <module>
result = grid_search.fit(x, y)

File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 813, in fit
return self._fit(X, y, ParameterGrid(self.param_grid))

 File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 566, in _fit
n_folds = len(cv)

TypeError: object of type 'generator' has no len()

换行

gkf = GroupKFold( n_splits=3).split(x,y,group)

gkf = GroupKFold( n_splits=3)

也不行。然后错误信息是:

'GroupKFold' object is not iterable

最佳答案

split GroupKFold 的函数 yield 训练和测试指标一次一对。您应该对拆分值调用 list 以将它们全部放入列表中,以便可以计算长度:

gkf = list(GroupKFold( n_splits=3).split(x,y,group))

关于python - 学习 : Cross validation for grouped data,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40369987/

相关文章:

python - 如何在 Python 中将二进制文件读取为十六进制?

pandas - 使用交叉验证分数获得零分

Python Scikit - 调用 sklearn.metrics. precision_recall_curve 时输入形状错误

r - 在 caret::train 函数中使用 bagImpute 预处理时出现缺失值错误

machine-learning - 缩放决策树中的数据会改变我的结果吗?

python - Pygame绘制抗锯齿粗线

python - 获取另一个应用程序的所有快捷方式的列表

python - 使用 Numpy 更有效地改变色调

python - 我无法为 python 3.7 安装 Scikit-learn

machine-learning - 如何使用 cross_val_score 来拟合我的测试数据?