scala - Spark DataFrame 过滤器无法按预期与随机一起工作

标签 scala apache-spark dataframe apache-spark-sql

这是我的数据框

df.groupBy($"label").count.show
+-----+---------+                                                               
|label|    count|
+-----+---------+
|  0.0|400000000|
|  1.0| 10000000|
+-----+---------+

我正在尝试使用以下命令对标签 == 0.0 的记录进行二次采样:

val r = scala.util.Random
val df2 = df.filter($"label" === 1.0 || r.nextDouble > 0.5) // keep 50% of 0.0

我的输出如下所示:

df2.groupBy($"label").count.show
+-----+--------+                                                                
|label|   count|
+-----+--------+
|  1.0|10000000|
+-----+--------+

最佳答案

r.nextDouble 是表达式中的常量,因此实际评估与您的意思有很大不同。根据实际采样值,它是:

scala> r.setSeed(0)

scala> $"label" === 1.0 || r.nextDouble > 0.5
res0: org.apache.spark.sql.Column = ((label = 1.0) OR true)

scala> r.setSeed(4096)

scala> $"label" === 1.0 || r.nextDouble > 0.5
res3: org.apache.spark.sql.Column = ((label = 1.0) OR false)

简化后就是:

true

(保留所有记录)或

label = 1.0 

(仅保留您观察到的情况)。

要生成随机数,您应该使用 corresponding SQL function

scala> import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.functions.rand

scala> $"label" === 1.0 || rand > 0.5
res1: org.apache.spark.sql.Column = ((label = 1.0) OR (rand(3801516599083917286) > 0.5))

虽然Spark已经提供了分层采样工具:

df.stat.sampleBy(
  "label",  // column
  Map(0.0 -> 0.5, 1.0 -> 1.0),  // fractions
  42 // seed 
)

关于scala - Spark DataFrame 过滤器无法按预期与随机一起工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54769360/

相关文章:

scala - 一对行上的 Spark Dataframe 滑动窗口

读取以竖线和逗号分隔的文件 : |column1|, |column2|

scala - Map到TreeMap的深度转换

scala - Spark Hadoop 广播失败

使用类型参数时的 Scala "takes no type parameters, expected: one"

scala - 如何在Intellij IDEA中运行Spark示例程序

hadoop - 应用程序接受和运行状态之间耗时

r - 使用 R 组合所有行对

android - 在使用 Android SDK 插件编译 Scala Android 多项目期间,X 已被定义为对象 X

scala - Shapeless:检查多态函数的类型约束