我想在一个非常大的数据集中找到连续的时间戳。这需要使用 Java 在 Spark 中完成(也非常欢迎 Scala 中的代码示例)。
每一行如下所示:
ID、开始时间、结束时间
例如数据集:
[[1,10,15],[1,15,20],[2,10,13],[1,22,33],[2,13,16]]
预期结果是同一 ID 的所有连续时间范围,每个连续时间范围只有开始和结束时间:
[[1, 10, 20],[1, 22, 33], [2, 10, 16]]
我已经尝试过以下操作,但由于未维护订单,因此不起作用。因此我希望有一种更有效的方法来做到这一点
textFile.mapToPair(x -> new Tuple2<>(x[0],new Tuple2<>(x[1], x[2])
.mapValues(x -> new LinkedList<>(Arrays.asList(x)))
.reduceByKey((x,y) -> {
Tuple2<Long, Long> v1 = x.getLast();
Tuple2<Long, Long> v2 = y.getFirst();
Tuple2<Long, Long> v3 = v2;
if(v2._1().equals(v1._2())) {
v3 = new Tuple2<>(v1._1(), v2._2());
x.removeLast();
}
x.addLast(v3);
return x;
})
.flatMapValues(x -> x);
最佳答案
我认为这不是 Spark 问题,而是逻辑问题。 您应该考虑使用多个独立函数的选项:
- 将两个间隔绑定(bind)在一起(我们将其命名为
bindEntries()
) - 将新间隔添加到间隔的间隔累加器中(设为
insertEntry()
)
建议,我们有模拟数据mockData
:
+---+-----+---+
| id|start|end|
+---+-----+---+
| 1| 22| 33|
| 1| 15| 20|
| 1| 10| 15|
| 2| 13| 16|
| 2| 10| 13|
+---+-----+---+
借助这些函数,我对您问题的解决方案将是这样的:
val processed = mockData
.groupByKey(_.id)
.flatMapGroups { (id: Int, it: Iterator[Entry]) =>
processEntries(it)
}
processEntries()
的唯一目标是将每个 id 的所有条目折叠到非相交间隔的集合中。
这是它的签名:
def processEntries(it: Iterator[Entry]): List[Entry] =
it.foldLeft(Nil: List[Entry])(insertEntry)
此函数用于从分组条目中一一获取元素,并将它们一一插入累加器。
处理这种插入的函数insertEntry()
:
def insertEntry(acc: List[Entry], e: Entry): List[Entry] = acc match {
case Nil => e :: Nil
case a :: as =>
val combined = bindEntries(a, e)
combined match {
case x :: y :: Nil => x :: insertEntry(as, y)
case x :: Nil => insertEntry(as, x)
case _ => a :: as
}
}
bindEntries()
函数应该为您处理条目的顺序:
def bindEntries(x: Entry, y: Entry): List[Entry] =
(x.start > y.end, x.end < y.start) match {
case (true, _) => y :: x :: Nil
case (_, true) => x :: y :: Nil
case _ => x.copy(start = x.start min y.start, end = x.end max y.end) :: Nil
}
bindEntries()
将返回一个或两个正确排序条目的列表。
这就是其背后的想法:
insertEntry()
将在插入过程中为您对所有条目进行排序。
毕竟,生成的数据集如下所示:
+---+-----+---+
| id|start|end|
+---+-----+---+
| 1| 10| 20|
| 1| 22| 33|
| 2| 10| 16|
+---+-----+---+
注意:函数 insertEntry()
不是尾递归。
这是进一步优化的良好起点。
还有完整的解决方案:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
object AdHoc {
Logger.getLogger("org").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
import spark.implicits._
val processed = mockData
.groupByKey(_.id)
.flatMapGroups { (id, it) =>
processEntries(it)
}
mockData.show()
processed.show()
}
def processEntries(it: Iterator[Entry]): List[Entry] =
it.foldLeft(Nil: List[Entry])(insertEntry)
def insertEntry(acc: List[Entry], e: Entry): List[Entry] = acc match {
case Nil => e :: Nil
case a :: as =>
val combined = bindEntries(a, e)
combined match {
case x :: y :: Nil => x :: insertEntry(as, y)
case x :: Nil => insertEntry(as, x)
case _ => a :: as
}
}
def bindEntries(x: Entry, y: Entry): List[Entry] =
(x.start > y.end, x.end < y.start) match {
case (true, _) => y :: x :: Nil
case (_, true) => x :: y :: Nil
case _ => x.copy(start = x.start min y.start, end = x.end max y.end) :: Nil
}
lazy val mockData: Dataset[Entry] = spark.createDataset(Seq(
Entry(1, 22, 33),
Entry(1, 15, 20),
Entry(1, 10, 15),
Entry(2, 13, 16),
Entry(2, 10, 13)
))
case class Entry(id: Int, start: Int, end: Int)
implicit lazy val entryEncoder: Encoder[Entry] = Encoders.product[Entry]
lazy val spark: SparkSession = SparkSession.builder()
.master("local")
.getOrCreate()
}
关于java - Spark查找连续时间范围,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56384916/