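Similar to "How to pass a parameter to only one part of a pipeline object in scikit learn?", I want to pass a parameter to only one step of a pipeline. Normally this should just work. The pipeline is set up like this: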
from sklearn.pipeline import Pipeline
from xgboost import XGBClassifier

estimator = XGBClassifier()
pipeline = Pipeline([
    ('clf', estimator)
])
and fitted like this:
pipeline.fit(X_train, y_train, clf__early_stopping_rounds=20)
But it fails:
/usr/local/lib/python3.5/site-packages/sklearn/pipeline.py in fit(self, X, y, **fit_params)
114 """
115 Xt, yt, fit_params = self._pre_transform(X, y, **fit_params)
--> 116 self.steps[-1][-1].fit(Xt, yt, **fit_params)
117 return self
118
/usr/local/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py in fit(self, X, y, sample_weight, eval_set, eval_metric, early_stopping_rounds, verbose)
443 early_stopping_rounds=early_stopping_rounds,
444 evals_result=evals_result, obj=obj, feval=feval,
--> 445 verbose_eval=verbose)
446
447 self.objective = xgb_options["objective"]
/usr/local/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/training.py in train(params, dtrain, num_boost_round, evals, obj, feval, maximize, early_stopping_rounds, evals_result, verbose_eval, learning_rates, xgb_model, callbacks)
201 evals=evals,
202 obj=obj, feval=feval,
--> 203 xgb_model=xgb_model, callbacks=callbacks)
204
205
/usr/local/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/training.py in _train_internal(params, dtrain, num_boost_round, evals, obj, feval, xgb_model, callbacks)
97 end_iteration=num_boost_round,
98 rank=rank,
---> 99 evaluation_result_list=evaluation_result_list))
100 except EarlyStopException:
101 break
/usr/local/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/callback.py in callback(env)
196 def callback(env):
197 """internal function"""
--> 198 score = env.evaluation_result_list[-1][1]
199 if len(state) == 0:
200 init(env)
IndexError: list index out of range
whereas
estimator.fit(X_train, y_train, early_stopping_rounds=20)
works fine.
Best answer
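To use early stopping rounds, you must always supply a validation set through the eval_set parameter. Here is how to fix the error in your code, where test_X and test_y are a held-out validation set: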
pipeline.fit(X_train, y_train, clf__early_stopping_rounds=20, clf__eval_set=[(test_X, test_y)])
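To make the step__param naming concrete, here is a minimal pure-Python sketch of how prefixed keyword arguments can be split up and handed to a single step. The helper route_fit_params is hypothetical and written only for illustration; it is not scikit-learn's actual implementation, and the strings in validation_pair are placeholders for real validation data. It shows that without clf__eval_set the 'clf' step receives no eval_set at all, which, judging by the traceback, leaves the early stopping callback with an empty evaluation_result_list to index.

```python
def route_fit_params(step_names, fit_params):
    """Split 'step__param' keyword arguments into one dict per step.

    Hypothetical helper for illustration only; not scikit-learn's code.
    """
    routed = {name: {} for name in step_names}
    for key, value in fit_params.items():
        # 'clf__eval_set' -> step 'clf', parameter 'eval_set'
        step, sep, param = key.partition("__")
        if not sep or step not in routed:
            raise ValueError(f"cannot route fit parameter {key!r}")
        routed[step][param] = value
    return routed


# Placeholder standing in for a real validation set such as (test_X, test_y).
validation_pair = ("X_valid", "y_valid")

# The failing call from the question: only early_stopping_rounds is passed.
without_eval = route_fit_params(["clf"], {"clf__early_stopping_rounds": 20})

# The fixed call from the answer: eval_set is routed to the same step.
with_eval = route_fit_params(
    ["clf"],
    {"clf__early_stopping_rounds": 20, "clf__eval_set": [validation_pair]},
)

print(without_eval["clf"])
print(with_eval["clf"])
```

In the second call both parameters arrive at the 'clf' step with the prefix stripped, which matches what the fixed pipeline.fit call above is meant to achieve.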
Source: "python - Sklearn pass fit() parameters to xgboost in pipeline" on Stack Overflow: https://stackoverflow.com/questions/40329576/