r - Spark ML 管道 Logistic 回归产生的预测比 R GLM 差得多

标签 r scala apache-spark apache-spark-ml

我使用 ML PipeLine 来运行逻辑回归模型,但由于某些原因,我得到的结果比 R 最差。我做了一些研究,我发现与此问题相关的唯一帖子是 this 。看来Spark Logistic Regression returns models that minimize loss function而 R glm 函数使用最大似然。 Spark 模型仅正确预测了 71.3% 的记录,而 R 可以正确预测 95.55% 的情况。我想知道我是否在设置上做错了什么,以及是否有办法改进预测。下面是我的 Spark 代码和 R 代码-

Spark 代码

partial model_input  
label,AGE,GENDER,Q1,Q2,Q3,Q4,Q5,DET_AGE_SQ  
1.0,39,0,0,1,0,0,1,31.55709342560551  
1.0,54,0,0,0,0,0,0,83.38062283737028  
0.0,51,0,1,1,1,0,0,35.61591695501733



def trainModel(df: DataFrame): PipelineModel = {  
  val lr  = new LogisticRegression().setMaxIter(100000).setTol(0.0000000000000001)  
  val pipeline = new Pipeline().setStages(Array(lr))  
  pipeline.fit(df)  
}

val meta =  NominalAttribute.defaultAttr.withName("label").withValues(Array("a", "b")).toMetadata

val assembler = new VectorAssembler().
  setInputCols(Array("AGE","GENDER","DET_AGE_SQ",
 "QA1","QA2","QA3","QA4","QA5")).
  setOutputCol("features")

val model = trainModel(model_input)
val pred= model.transform(model_input)  
pred.filter("label!=prediction").count

R代码

lr <- model_input %>% glm(data=., formula=label~ AGE+GENDER+Q1+Q2+Q3+Q4+Q5+DET_AGE_SQ,
          family=binomial)
pred <- data.frame(y=model_input$label,p=fitted(lr))
table(pred $y, pred $p>0.5)

如果您需要任何其他信息,请随时告诉我。谢谢!

编辑 9/18/2015 我尝试过增加最大迭代次数并显着降低容差。不幸的是,它并没有改善预测。模型似乎收敛到局部最小值而不是全局最小值。

最佳答案

It seems that Spark Logistic Regression returns models that minimize loss function while R glm function uses maximum likelihood.

损失函数的最小化几乎是线性模型的定义,并且 glmml.classification.LogisticRegression 在这里没有什么不同。两者之间的根本区别在于实现方式。

ML/MLlib 中的所有线性模型均基于 Gradient descent 的某些变体。使用此方法生成的模型的质量因具体情况而异,并且取决于梯度下降和正则化参数。

另一方面,R 计算精确的解决方案,考虑到其时间复杂度,它不太适合大型数据集。

正如我上面提到的,使用 GS 生成的模型的质量取决于输入参数,因此改进模型的典型方法是执行超参数优化。不幸的是,与 MLlib 相比,ML 版本相当有限,但对于初学者来说,您可以增加迭代次数。

关于r - Spark ML 管道 Logistic 回归产生的预测比 R GLM 差得多,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32642789/

相关文章:

scala - FlinkML:加入 LabeledVector 的数据集不起作用

java - 将 JavaRDD<ArrayList<T>> 转换为 JavaRDD<T>

R 包 : writing internal data, 但不是一次全部

r - R igraph中添加节点和删除指定边

r - 确定模型公式是否只有截距的最简单方法

scala - 部分应用类型参数

r - ggplot2 中的 panel.border 在 Cairo PDF 设备的图的底部和右侧绘制粗线

Scala Actor 消息定义

scala - 更改 DataFrame 中嵌套列的值

java - Spark中读数据库的执行时间