I want to use Spark to extract data, as shown below.
DataSetTest ro1 = new DataSetTest("apple", "fruit", "red", 3);
DataSetTest ro2 = new DataSetTest("apple", "fruit", "red", 4);
DataSetTest ro3 = new DataSetTest("car", "toy", "red", 1);
DataSetTest ro4 = new DataSetTest("bike", "toy", "white", 2);
DataSetTest ro5 = new DataSetTest("bike", "toy", "red", 5);
DataSetTest ro6 = new DataSetTest("apple", "fruit", "red", 3);
DataSetTest ro7 = new DataSetTest("car", "toy", "white", 7);
DataSetTest ro8 = new DataSetTest("apple", "fruit", "green", 1);
Dataset<Row> df = session.getSqlContext().createDataFrame(Arrays.asList(ro1, ro2, ro3, ro4, ro5, ro6, ro7, ro8), DataSetTest.class);
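The `DataSetTest` bean itself is not shown in the question. For `createDataFrame` to infer a schema by reflection it needs a serializable bean with getters; a minimal sketch might look like the following (the field names are assumptions based on the columns used above — note the accepted answer below refers to the count column as `cnt` instead):

```java
import java.io.Serializable;

// Hypothetical sketch of the bean the question's createDataFrame call assumes.
// Spark derives the DataFrame columns from the getter names: keyword, opt1, opt2, count.
public class DataSetTest implements Serializable {
    private final String keyword;
    private final String opt1;
    private final String opt2;
    private final int count;

    public DataSetTest(String keyword, String opt1, String opt2, int count) {
        this.keyword = keyword;
        this.opt1 = opt1;
        this.opt2 = opt2;
        this.count = count;
    }

    public String getKeyword() { return keyword; }
    public String getOpt1() { return opt1; }
    public String getOpt2() { return opt2; }
    public int getCount() { return count; }
}
```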
private void process(){
//1) groupByKey
Dataset<Row> df2 = df.groupBy("keyword", "opt1", "opt2").sum("count");
//2) count per opt1/opt2 combination & calculate the total
Dataset<Row> df3 = df2.withColumn("fruit_red", **???**)
.withColumn("fruit_green", **???**)
.withColumn("toy_red", **???**)
.withColumn("toy_white",**???**)
.withColumn("total_count", ???);
//3) calculate the percent
Dataset<Row> df4 = df3.withColumn("percent", df3.col("total_count").divide("??sum of total_count??"));
Do you know how to compute parts 2) and 3)?
Best Answer
I'm not a Java expert, but you can do something like this:
Logger.getLogger("org").setLevel(Level.ERROR) ;
DataSetTest ro1 = new DataSetTest("apple", "fruit", "red", 3);
DataSetTest ro2 = new DataSetTest("apple", "fruit", "red", 4);
DataSetTest ro3 = new DataSetTest("car", "toy", "red", 1);
DataSetTest ro4 = new DataSetTest("bike", "toy", "white", 2);
DataSetTest ro5 = new DataSetTest("bike", "toy", "red", 5);
DataSetTest ro6 = new DataSetTest("apple", "fruit", "red", 3);
DataSetTest ro7 = new DataSetTest("car", "toy", "white", 7);
DataSetTest ro8 = new DataSetTest("apple", "fruit", "green", 1);
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("SaavnAnalyticsProject");
SparkSession sc = SparkSession.builder().config(conf).getOrCreate();
Dataset<Row> df = sc.createDataFrame(Arrays.asList(ro1, ro2, ro3, ro4, ro5, ro6, ro7, ro8), DataSetTest.class);
Dataset<Row> groupedDf = df.groupBy(col("keyword"), col("opt1"), col("opt2")).sum("cnt");
groupedDf = groupedDf.withColumn("concatCol", concat(col("opt1"), lit("_"), col("opt2")));
groupedDf = groupedDf.drop(col("opt1")).drop(col("opt2"));
groupedDf.show();
Dataset<Row> pivotedDF = groupedDf.groupBy(col("keyword")).pivot("concatCol").sum("sum(cnt)").na().fill(0);
String[] cols = ArrayUtil.removeFromArray(pivotedDF.columns(), "keyword");
String exp = String.join(" + ", cols);
System.out.println(exp);
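`ArrayUtil.removeFromArray` above appears to be a third-party helper that drops one element from an array; if you'd rather avoid the extra dependency, the same `"a + b + c"` expression string can be built with plain `java.util.stream`, e.g.:

```java
import java.util.Arrays;
import java.util.stream.Collectors;

public class BuildExpr {
    // Join every pivoted column except the grouping key into a "+"-separated
    // SQL expression, e.g. "fruit_green + fruit_red + toy_red + toy_white".
    static String sumExpression(String[] columns, String keyCol) {
        return Arrays.stream(columns)
                .filter(c -> !c.equals(keyCol))
                .collect(Collectors.joining(" + "));
    }

    public static void main(String[] args) {
        String[] cols = {"keyword", "fruit_green", "fruit_red", "toy_red", "toy_white"};
        System.out.println(sumExpression(cols, "keyword"));
        // fruit_green + fruit_red + toy_red + toy_white
    }
}
```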
pivotedDF = pivotedDF.withColumn("total", expr(exp));
pivotedDF.show();
The result is as follows:
+-------+-----------+---------+-------+---------+-----+
|keyword|fruit_green|fruit_red|toy_red|toy_white|total|
+-------+-----------+---------+-------+---------+-----+
|  apple|          1|       10|      0|        0|   11|
|    car|          0|        0|      1|        7|    8|
|   bike|          0|        0|      5|        2|    7|
+-------+-----------+---------+-------+---------+-----+
then:
Long sum = pivotedDF.agg(sum("total")).first().getLong(0);
pivotedDF = pivotedDF
.withColumn("sum", lit(sum))
.withColumn("percent", col("total")
.divide(col("sum"))).drop(col("sum"));
The result is:
+-------+-----------+---------+-------+---------+-----+------------------+
|keyword|fruit_green|fruit_red|toy_red|toy_white|total|           percent|
+-------+-----------+---------+-------+---------+-----+------------------+
|  apple|          1|       10|      0|        0|   11|0.4230769230769231|
|    car|          0|        0|      1|        7|    8|0.3076923076923077|
|   bike|          0|        0|      5|        2|    7|0.2692307692307692|
+-------+-----------+---------+-------+---------+-----+------------------+
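As a quick sanity check on the percent column: the grand total is 11 + 8 + 7 = 26, and each percent is simply total / 26:

```java
public class PercentCheck {
    public static void main(String[] args) {
        long[] totals = {11, 8, 7};          // per-keyword totals from the table above
        long grand = 0;
        for (long t : totals) grand += t;    // 26
        for (long t : totals) {
            System.out.println((double) t / grand);
        }
        // 0.4230769230769231
        // 0.3076923076923077
        // 0.2692307692307692
    }
}
```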
You could write this more readably in Python or Scala.
Regarding "java - I want to know how to count with a filter in Spark withColumn", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/64634343/