apache-spark - 在 Spark 中计算逐点互信息

标签 apache-spark apache-spark-mllib

我正在尝试计算 pointwise mutual information (PMI)。

enter image description here

我在这里分别为 p(x, y) 和 p(x) 定义了两个 RDD:

pii: RDD[((String, String), Double)]
 pi: RDD[(String, Double)]

我正在编写的任何用于从 RDD 计算 PMI 的代码 piipi不漂亮。我的做法是先把RDD拉平pii并加入 pi两次同时按摩元组元素。

val pmi = pii.map(x => (x._1._1, (x._1._2, x._1, x._2)))
             .join(pi).values
             .map(x => (x._1._1, (x._1._2, x._1._3, x._2)))
             .join(pi).values
             .map(x => (x._1._1, computePMI(x._1._2, x._1._3, x._2)))
// pmi: org.apache.spark.rdd.RDD[((String, String), Double)]
...
def computePMI(pab: Double, pa: Double, pb: Double) = {
  // handle boundary conditions, etc
  log(pab) - log(pa) - log(pb)
}

显然,这很糟糕。有没有更好的(惯用的)方法来做到这一点?
注意:我可以通过将 log-probs 存储在 pi 中来优化日志。和 pii但选择这样写是为了让问题保持清晰。

最佳答案

使用 broadcast将是一个解决方案。

val bcPi = pi.context.broadcast(pi.collectAsMap())
val pmi = pii.map {
  case ((x, y), pxy) =>
    (x, y) -> computePMI(pxy, bcPi.value.get(x).get, bcPi.value.get(y).get)
}

假设:pi拥有所有 xypii .

关于apache-spark - 在 Spark 中计算逐点互信息,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29620297/

相关文章:

scala - 如何使用程序化 Spark 提交功能

scala - 如何在spark 2.2中模拟array_join()方法

apache-spark - 撤消缩放数据 pyspark

apache-spark - Spark作业执行时间

scala - 从 DataFrame 到 RDD[LabeledPoint]

scala - 尽管文件大小超出了执行器内存,但如何使用一个分区将数据帧写入 csv 文件

python - PySpark:添加一个新列,其中包含从列创建的元组

apache-spark - Spark2.3.0-bin-without-hadoop,docker-image-tool.sh 缺少 hadoop jar

java - 为什么 StreamingKMeans 聚类中心与常规 Kmeans 不同

apache-spark - 在理解 MLlib 中的 LDA 主题模型时遇到麻烦