lightgbm - How do I correctly save an mlr3 lightgbm model?

Tags: lightgbm mlr3

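I have the following code. Saving the trained model and loading it again leads to an error at prediction time. The error only occurs when I use lightgbm.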

library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)

data = tsk("german_credit")$data()
data = data[, c("credit_risk", "amount", "purpose", "age")]
task = TaskClassif$new("boston", backend = data, target = "credit_risk")

g = po("imputemedian") %>>%
  po("imputeoor") %>>%
  po("fixfactors") %>>%
  po("encodeimpact") %>>% 
  lrn("classif.lightgbm")

gl = GraphLearner$new(g)

gl$train(task)

# predict 
newdata <- data[1,]
gl$predict_newdata(newdata) 
saveRDS(gl, "gl.rds")
# read model from disk ----------------
gl <- readRDS("gl.rds")
newdata <- data[1,]

# error when predict ------------------
gl$predict_newdata(newdata)

Best answer

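lightgbm uses special functions to save and read models. You have to extract the model before saving and add it back to the graph learner after loading. However, this might not be practical for benchmarks. We will look into it.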

library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(lightgbm)

data = tsk("german_credit")$data()
data = data[, c("credit_risk", "amount", "purpose", "age")]
task = TaskClassif$new("boston", backend = data, target = "credit_risk")

g = po("imputemedian") %>>%
  po("imputeoor") %>>%
  po("fixfactors") %>>%
  po("encodeimpact") %>>% 
  lrn("classif.lightgbm")

gl = GraphLearner$new(g)

gl$train(task)

# save the lightgbm booster with lightgbm's own serializer
saveRDS.lgb.Booster(gl$model$classif.lightgbm$model, "model.rds")

# save the graph learner (preprocessing steps and learner state)
saveRDS(gl, "gl.rds")

# load the booster
model = readRDS.lgb.Booster("model.rds")

# load the graph learner
gl = readRDS("gl.rds")

# put the restored booster back into the graph learner
gl$state$model$classif.lightgbm$model = model

# predict
newdata <- data[1,]
gl$predict_newdata(newdata)
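
The extract-and-reattach steps above can be wrapped in two small helpers so they are not repeated by hand. This is only a sketch: `save_graph_learner()` and `load_graph_learner()` are hypothetical names, not part of mlr3 or lightgbm. It assumes the booster sits at `gl$state$model[[po_id]]$model`, as in the answer. The booster writer and reader are passed in as arguments because `saveRDS.lgb.Booster()` and `readRDS.lgb.Booster()` may not be available in newer lightgbm releases. The defaults use `lightgbm::lgb.save()` and `lightgbm::lgb.load()`, and whether a booster restored that way predicts identically inside the graph learner is an assumption that should be checked.

```r
# Hypothetical helpers, not part of mlr3 or lightgbm.
# Assumption: the trained booster is stored at gl$state$model[[po_id]]$model.

save_graph_learner <- function(gl, learner_path, booster_path,
                               po_id = "classif.lightgbm",
                               write_booster = lightgbm::lgb.save) {
  # Write the booster with a lightgbm-aware serializer
  write_booster(gl$state$model[[po_id]]$model, booster_path)
  # Write the rest of the pipeline (imputation, encoding, ...) as usual
  saveRDS(gl, learner_path)
  invisible(c(learner = learner_path, booster = booster_path))
}

load_graph_learner <- function(learner_path, booster_path,
                               po_id = "classif.lightgbm",
                               read_booster = lightgbm::lgb.load) {
  gl <- readRDS(learner_path)
  # Replace the booster stored in the RDS file with the separately saved one
  gl$state$model[[po_id]]$model <- read_booster(booster_path)
  gl
}
```

With the objects from the answer, this would be called as `save_graph_learner(gl, "gl.rds", "model.txt")` and later `gl = load_graph_learner("gl.rds", "model.txt")`. To keep the exact functions used in the answer, pass `write_booster = saveRDS.lgb.Booster` and `read_booster = readRDS.lgb.Booster`.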

Source: a similar question, "lightgbm - How do I correctly save an mlr3 lightgbm model?", can be found on Stack Overflow: https://stackoverflow.com/questions/69793716/
