这是我的数据框
df.groupBy($"label").count.show
+-----+---------+
|label| count|
+-----+---------+
| 0.0|400000000|
| 1.0| 10000000|
+-----+---------+
我正在尝试使用以下命令对标签 == 0.0 的记录进行二次采样:
val r = scala.util.Random
val df2 = df.filter($"label" === 1.0 || r.nextDouble > 0.5) // keep 50% of 0.0
我的输出如下所示:
df2.groupBy($"label").count.show
+-----+--------+
|label| count|
+-----+--------+
| 1.0|10000000|
+-----+--------+
最佳答案
r.nextDouble
是表达式中的常量,因此实际评估与您的意思有很大不同。根据实际采样值,它是:
scala> r.setSeed(0)
scala> $"label" === 1.0 || r.nextDouble > 0.5
res0: org.apache.spark.sql.Column = ((label = 1.0) OR true)
或
scala> r.setSeed(4096)
scala> $"label" === 1.0 || r.nextDouble > 0.5
res3: org.apache.spark.sql.Column = ((label = 1.0) OR false)
简化后就是:
true
(保留所有记录)或
label = 1.0
(仅保留您观察到的情况)。
要生成随机数,您应该使用 corresponding SQL function
scala> import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.functions.rand
scala> $"label" === 1.0 || rand > 0.5
res1: org.apache.spark.sql.Column = ((label = 1.0) OR (rand(3801516599083917286) > 0.5))
虽然Spark已经提供了分层采样工具:
df.stat.sampleBy(
"label", // column
Map(0.0 -> 0.5, 1.0 -> 1.0), // fractions
42 // seed
)
关于scala - Spark DataFrame 过滤器无法按预期与随机一起工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54769360/