python - 使用 LightGBM 进行多类分类

标签 python machine-learning predict multiclass-classification lightgbm

我正在尝试使用 Python 中的 LightGBM 为多类分类问题(3 类)建立分类器模型。我使用了以下参数。

params = {'task': 'train',
    'boosting_type': 'gbdt',
    'objective': 'multiclass',
    'num_class':3,
    'metric': 'multi_logloss',
    'learning_rate': 0.002296,
    'max_depth': 7,
    'num_leaves': 17,
    'feature_fraction': 0.4,
    'bagging_fraction': 0.6,
    'bagging_freq': 17}

数据集的所有分类特征均使用LabelEncoder进行标签编码。我在使用 eartly_stopping 运行 cv 后训练了模型,如下所示。

lgb_cv = lgbm.cv(params, d_train, num_boost_round=10000, nfold=3, shuffle=True, stratified=True, verbose_eval=20, early_stopping_rounds=100)

nround = lgb_cv['multi_logloss-mean'].index(np.min(lgb_cv['multi_logloss-mean']))
print(nround)

model = lgbm.train(params, d_train, num_boost_round=nround)

训练后,我用这样的模型进行了预测,

preds = model.predict(test)
print(preds)             

我得到了一个嵌套数组作为这样的输出。

[[  7.93856847e-06   9.99989550e-01   2.51164967e-06]
 [  7.26332978e-01   1.65316511e-05   2.73650491e-01]
 [  7.28564308e-01   8.36756769e-06   2.71427325e-01]
 ..., 
 [  7.26892634e-01   1.26915179e-05   2.73094674e-01]
 [  5.93217601e-01   2.07172044e-04   4.06575227e-01]
 [  5.91722491e-05   9.99883828e-01   5.69994435e-05]]

由于 preds 中的每个列表都代表类概率,因此我使用 np.argmax() 来查找这样的类..

predictions = []

for x in preds:
    predictions.append(np.argmax(x))

在分析预测时,我发现我的预测仅包含 2 个类别 - 0 和 1。类别 2 是训练集中的第二大类别,但在预测中找不到它。在评估结果时,它准确率约为 78%

那么,为什么我的模型在任何情况下都没有预测 2 类。?我使用的参数有问题吗?

这不是解释模型做出的预测的正确方法吗?我应该对参数进行任何更改吗??

最佳答案

尝试通过交换类别 0 和 2 并重新运行训练和预测过程来进行故障排除。

如果新的预测仅包含类别 1 和类别 2(很可能考虑到您提供的数据):

  • 分类器可能还没有学到第三类;也许它的特征与较大类的特征重叠,并且分类器默认为较大类,以便最小化目标函数。尝试提供平衡的训练集(每个类的样本数量相同)并重试。

如果新的预测确实包含所有 3 个类别:

  • 您的代码中某处出了问题。需要更多信息来确定到底出了什么问题。

希望这有帮助。

关于python - 使用 LightGBM 进行多类分类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47370240/

相关文章:

python - 在Python3中从zip存档中提取特定文件夹的内容

Python 单元测试忽略 numpy

python - 获取one-hot编码的H2OFrame

machine-learning - LightGBM 中的 Bagging 是如何工作的

使用元类的 Python 类设计

python - numpy std计算: TypeError: cannot perform reduce with flexible type

python - 用于确定 k 均值中的 k 的 k 折交叉验证?

r - 计算预测值时发出警告

r - 如何在 Terra::Predict 中格式化 'const' 参数?

r - 如何使用预测计算 R 中预测数据的标准误差