python - sklearn.ensemble.AdaBoostClassifier 不能接受 SVM 作为 base_estimator?

标签 python machine-learning scikit-learn ensemble-learning

我正在做一个文本分类任务。现在我想使用 ensemble.AdaBoostClassifierLinearSVC 作为 base_estimator。但是,当我尝试运行代码时

clf = AdaBoostClassifier(svm.LinearSVC(),n_estimators=50, learning_rate=1.0,    algorithm='SAMME.R')
clf.fit(X, y)

发生错误。 TypeError:AdaBoostClassifier with algorithm='SAMME.R' 要求弱学习器支持使用 predict_proba 方法计算类别概率

第一个问题是 svm.LinearSVC() 不能计算类别概率吗?如何让它计算概率?

然后我更改参数 algorithm 并再次运行代码。

clf = AdaBoostClassifier(svm.LinearSVC(),n_estimators=50, learning_rate=1.0, algorithm='SAMME')
clf.fit(X, y)

这次 TypeError: fit() got an unexpected keyword argument 'sample_weight' 发生了。正如AdaBoostClassifier中所说, 样本权重。如果为 None,则样本权重初始化为 1/n_samples。 即使我将整数分配给 n_samples,也会发生错误。

第二个问题是 n_samples 是什么意思?如何解决这个问题呢?

希望有人能帮助我。

根据@jme 的评论,然而,在尝试之后

clf = AdaBoostClassifier(svm.SVC(kernel='linear',probability=True),n_estimators=10,  learning_rate=1.0, algorithm='SAMME.R')
clf.fit(X, y)

程序无法获取结果,服务器上使用的内存保持不变。

第三个问题是如何让 AdaBoostClassifierSVC 一起作为 base_estimator 工作?

最佳答案

正确答案取决于您要查找的内容。 LinearSVC 无法预测类别概率(AdaBoostClassifier 使用的默认算法需要)并且不支持 sample_weight。

您应该知道,支持向量机不会名义上预测类别概率。它们是使用 Platt 缩放(或 Platt 缩放在多类情况下的扩展)计算的,这是一种存在已知问题的技术。如果您需要较少的“人工”类别概率,则 SVM 可能不是合适的选择。

话虽如此,对于您的问题,我相信最令人满意的答案是 Graham 给出的答案。也就是说,

from sklearn.svm import SVC
from sklearn.ensemble import AdaBoostClassifier

clf = AdaBoostClassifier(SVC(probability=True, kernel='linear'), ...)

您还有其他选择。您可以使用带有铰链损失函数的 SGDClassifier 并将 AdaBoostClassifier 设置为使用 SAMME 算法(不需要 predict_proba 函数,但需要支持 sample_weight):

from sklearn.linear_model import SGDClassifier

clf = AdaBoostClassifier(SGDClassifier(loss='hinge'), algorithm='SAMME', ...)

如果您想使用为 AdaBoostClassifier 提供的默认算法,也许最好的答案是使用对类别概率具有 native 支持的分类器,例如 Logistic 回归。您可以使用 scikit.linear_model.LogisticRegression 或使用带有对数损失函数的 SGDClassifier 来执行此操作,如 Kris 提供的代码中所用。

希望对您有所帮助,如果您对什么是 Platt 缩放感到好奇,check out the original paper by John Platt here .

关于python - sklearn.ensemble.AdaBoostClassifier 不能接受 SVM 作为 base_estimator?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27107205/

相关文章:

r - 分配损失时如何解释 rpart(方法 ="class")的 xerror

machine-learning - 用于分类/多类分类的梯度提升树的弱学习器

python - matplotlib:为什么绘制历史记录会在python中导致IndexError?

python - emacs `python-shell-send-defun` 跳过缓冲区中的第一行

python - 将 virtualenv 与 eclipse 一起使用

python - Python 中的方法记录

python - 使用网格搜索的交叉验证返回比默认更差的结果

algorithm - Elastic x-pack 插件使用的机器学习算法

python - 如何使用带有 GridSearchCV 对象的 TimeSeriesSplit 来调整 scikit-learn 中的模型?

python - pd.get_dummies 数据帧在 Sparse = True 时与 Sparse = False 时大小相同