python - XGBoost 成对设置 - python

标签 python ranking xgboost

在 XGBoost 中,我尝试了多种方法来使成对组与组集一起工作,但没有成功。以下代码在使用 set_group 时不起作用,但可以为 xgbTrain 注释掉 set_group

import xgboost
import pandas as pd
from xgboost import DMatrix,train

xgb_params ={    
    'booster' : 'gbtree',
    'eta': 0.1,
    'gamma' : 1.0 ,
    'min_child_weight' : 0.1,
    'objective' : 'rank:pairwise',
    'eval_metric' : 'merror',
    #'num_class': 3,  # 
    'max_depth' : 6,
    'num_round' : 4,
    'save_period' : 0 
}


n_group=2
n_choice=3    

#training dataset

dtrain=np.random.uniform(0,100,[n_group*n_choice,2])    
dtarget=np.array([np.random.choice([0,1,2],3,False) for i in range(n_group)]).flatten()
dgroup=np.array([np.repeat(i,3)for i in range(n_group)]).flatten()

xgbTrain = DMatrix(dtrain, label = dtarget)
xgbTrain =xgbTrain.set_group(dgroup)

#watchlist

dtrain_eval=np.random.uniform(0,100,[n_group*n_choice,2])        

xgbTrain_eval = DMatrix(dtrain_eval, label = dtarget)
#xgbTrain_eval =xgbTrain_eval .set_group(dgroup)

#test dataset

dtest=np.random.uniform(0,100,[n_group*n_choice,2])    
dtestgroup=np.array([np.repeat(i,3)for i in range(n_group)]).flatten()

xgbTest = DMatrix(dtest)
#xgbTest =xgbTest.set_group(dgroup)
evallist  = [(xgbTrain_eval, 'eval')]

rankModel = xgboost.train(params=xgb_params,dtrain=xgbTrain  )
print(rankModel.predict( xgbTest))

返回的错误似乎指向缺少 eval 数据,但甚至将 evals 指定为

 rankModel = xgboost.train(params=xgb_params,dtrain=xgbTrain,evals=evallist )

错误仍然存​​在。

请注意 num_class 被注释掉了,但直观上它的值应该是 3(这里对应于类的数量)或 2(在成对排名的情况下代表组的数量)?

任何帮助指出错误的地方?

(Xgboost 0.6)

最佳答案

一个错误: 我的杯子,set_group 不正确,应该是

     xgbTrain.set_group(dgroup)

不是

     xgbTrain =xgbTrain.set_group(dgroup)

解决方法:

set_group 中的数据应该只是每组每个项目的计数,每组一个项目。

      dgroup=np.array([n_choice for i in range(n_group)]).flatten()

成功了!

关于python - XGBoost 成对设置 - python,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49939323/

相关文章:

python - 如何提高pandas中mysql查询的处理速度

python - str没有附加属性错误

javascript - 对整数数组进行排序并显示输入值

arrays - 用少于排列长度的字节保存排列

python - 在 Ubuntu 上构建原生 webrtc

python - Flask-sqlalchemy 失去了与 MySQL 数据库的连接

python - 排名聚合: Merge local subrankings into global ranking

python - 为什么在部署我在本地运行良好的 Flask 应用程序时,Heroku 崩溃了(代码=H10)?

python - 无法加载 XGBoost 库 (libxgboost.so)

python - XGBoost特征重要性: How do I get original variable names after encoding