python - 在 h2o AutoML 上检索 h2o AutoML 的交叉验证性能 (AUC) for holdout dataset

标签 python machine-learning cross-validation h2o automl

我正在使用默认交叉验证 (nfolds=5) 使用 h2o AutoML 训练二元分类模型。我需要获得每个保留折叠的 AUC 分数,以便计算可变性。

这是我使用的代码:

h2o.init()

prostate = h2o.import_file("https://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate.csv")
# convert columns to factors
prostate['CAPSULE'] = prostate['CAPSULE'].asfactor()
prostate['RACE'] = prostate['RACE'].asfactor()
prostate['DCAPS'] = prostate['DCAPS'].asfactor()
prostate['DPROS'] = prostate['DPROS'].asfactor()

# set the predictor and response columns
predictors = ["AGE", "RACE", "VOL", "GLEASON"]
response_col = "CAPSULE"

# split into train and testing sets
train, test = prostate.split_frame(ratios = [0.8], seed = 1234)


aml = H2OAutoML(seed=1, max_runtime_secs=100, exclude_algos=["DeepLearning", "GLM"],
                    nfolds=5, keep_cross_validation_predictions=True)

aml.train(predictors, response_col, training_frame=prostate)

leader = aml.leader

我检查 leader 不是 StackedEnsamble 模型(验证指标不可用)。无论如何,我无法检索到五个 AUC 分数。

知道怎么做吗?

最佳答案

这是它是如何完成的:

import h2o
from h2o.automl import H2OAutoML

h2o.init()

# import prostate dataset
prostate = h2o.import_file("https://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate.csv")
# convert columns to factors
prostate['CAPSULE'] = prostate['CAPSULE'].asfactor()
prostate['RACE'] = prostate['RACE'].asfactor()
prostate['DCAPS'] = prostate['DCAPS'].asfactor()
prostate['DPROS'] = prostate['DPROS'].asfactor()

# set the predictor and response columns
predictors = ["AGE", "RACE", "VOL", "GLEASON"]
response_col = "CAPSULE"

# split into train and testing sets
train, test = prostate.split_frame(ratios = [0.8], seed = 1234)

# run AutoML for 100 seconds
aml = H2OAutoML(seed=1, max_runtime_secs=100, exclude_algos=["DeepLearning", "GLM"],
                    nfolds=5, keep_cross_validation_predictions=True)
aml.train(x=predictors, y=response_col, training_frame=prostate)

# Get the leader model
leader = aml.leader

这里有一个关于交叉验证 AUC 的注意事项——H2O 目前存储了 CV AUC 的两个计算。一个是聚合版本(采用聚合 CV 预测的 AUC),另一个是交叉验证 AUC 的“真实”定义(来自 k 折交叉验证的 k AUC 的平均值)。后者存储在一个对象中,该对象还包含各个折叠 AUC 以及折叠之间的标准差。

如果您想知道我们为什么这样做,有一些历史和技术原因导致我们有两个版本,以及一个仅对每个报告开放的 ticket 后者。

第一个是执行此操作时获得的结果(以及 AutoML 排行榜上显示的结果)。

# print CV AUC for leader model
print(leader.model_performance(xval=True).auc())

如果您想要折叠 AUC 以便计算或查看它们的均值和变异性(标准差),您可以通过查看此处来实现:

# print CV metrics summary
leader.cross_validation_metrics_summary()

输出:

Cross-Validation Metrics Summary:
             mean        sd           cv_1_valid    cv_2_valid    cv_3_valid    cv_4_valid    cv_5_valid
-----------  ----------  -----------  ------------  ------------  ------------  ------------  ------------
accuracy     0.71842104  0.06419111   0.7631579     0.6447368     0.7368421     0.7894737     0.65789473
auc          0.7767409   0.053587236  0.8206676     0.70905924    0.7982079     0.82538515    0.7303846
aucpr        0.6907578   0.0834025    0.78737605    0.7141305     0.7147677     0.67790955    0.55960524
err          0.28157896  0.06419111   0.23684211    0.35526314    0.2631579     0.21052632    0.34210527
err_count    21.4        4.8785243    18.0          27.0          20.0          16.0          26.0
---          ---         ---          ---           ---           ---           ---           ---
precision    0.61751753  0.08747421   0.675         0.5714286     0.61702126    0.7241379     0.5
r2           0.20118153  0.10781976   0.3014902     0.09386432    0.25050205    0.28393403    0.07611712
recall       0.84506994  0.08513061   0.84375       0.9142857     0.9354839     0.7241379     0.8076923
rmse         0.435928    0.028099842  0.41264254    0.47447023    0.42546       0.41106534    0.4560018
specificity  0.62579334  0.15424488   0.70454544    0.41463414    0.6           0.82978725    0.58

See the whole table with table.as_data_frame()

这是排行榜的样子(存储汇总的 CV AUC)。在这种情况下,由于数据非常小(300 行),因此两个报告的 CV AUC 值之间存在明显差异,但是对于较大的数据集,它们应该是更接近的估计值。

# print the whole Leaderboard (all CV metrics for all models)
lb = aml.leaderboard
print(lb)

这将打印排行榜的顶部:

model_id                                                  auc    logloss     aucpr    mean_per_class_error      rmse       mse
---------------------------------------------------  --------  ---------  --------  ----------------------  --------  --------
XGBoost_grid__1_AutoML_20200924_200634_model_2       0.769716   0.565326  0.668827                0.290806  0.436652  0.190665
GBM_grid__1_AutoML_20200924_200634_model_4           0.762993   0.56685   0.666984                0.279145  0.437634  0.191524
XGBoost_grid__1_AutoML_20200924_200634_model_9       0.762417   0.570041  0.645664                0.300121  0.440255  0.193824
GBM_grid__1_AutoML_20200924_200634_model_6           0.759912   0.572651  0.636713                0.30097   0.440755  0.194265
StackedEnsemble_BestOfFamily_AutoML_20200924_200634  0.756486   0.574461  0.646087                0.294002  0.441413  0.194845
GBM_grid__1_AutoML_20200924_200634_model_7           0.754153   0.576821  0.641462                0.286041  0.442533  0.195836
XGBoost_1_AutoML_20200924_200634                     0.75411    0.584216  0.626074                0.289237  0.443911  0.197057
XGBoost_grid__1_AutoML_20200924_200634_model_3       0.753347   0.57999   0.629876                0.312056  0.4428    0.196072
GBM_grid__1_AutoML_20200924_200634_model_1           0.751706   0.577175  0.628564                0.273603  0.442751  0.196029
XGBoost_grid__1_AutoML_20200924_200634_model_8       0.749446   0.576686  0.610544                0.27844   0.442314  0.195642

[28 rows x 7 columns]

关于python - 在 h2o AutoML 上检索 h2o AutoML 的交叉验证性能 (AUC) for holdout dataset,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64032018/

相关文章:

python-3.x - (无法将字符串转换为 float )使用 knn 算法时出错

python - 尝试 kfold cv 时出现类型错误 : only integer scalar arrays can be converted to a scalar index ,

python - XGboost 的 RandomizedSearchCV、不平衡数据集和最佳迭代次数 (n_iter)

python - Django 无法使用 "python manage.py syncdb"命令创建默认表

python - 使用 python StandardScaler 进行特征缩放会产生负值

Python 3,如何将字符串转换为 "iso-8859-1"以在html中使用

algorithm - 算法训练阶段的健全性检查

matlab - 为什么我们需要在图像分类的 multiSVM 方法中进行交叉验证?

仅当表达式的值不是 None 时才返回表达式的 Python 语法

python - 从 Github 安装 Python 包