我正在尝试测试在 Spark 文档中找到的这段代码,以便使用 Java 处理 Apache Spark 中的分类功能:
SparkSession spark = SparkSession
.builder().master("local[4]")
.appName("1-of-K encoding Test")
.getOrCreate();
List<Row> data = Arrays.asList(
RowFactory.create(0, "a"),
RowFactory.create(1, "b"),
RowFactory.create(2, "c"),
RowFactory.create(3, "a"),
RowFactory.create(4, "a"),
RowFactory.create(5, "c")
);
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false,Metadata.empty()),
new StructField("category", DataTypes.StringType, false, Metadata.empty())
});
Dataset<Row> df = spark.createDataFrame(data, schema);
StringIndexerModel indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df);
但是我收到了这个错误;无法调用 fit 函数
你有什么想法吗?
最佳答案
为什么要在较长的 route 创建 df?更有效的方法是:
import sparkSession.implicits._
val df = sparkSession.sparkContext.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))).toDF("id", "category")
val newDf = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df)
.transform(df)
.show;
输出:
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
| 0| a| 2.0|
| 1| b| 3.0|
| 2| c| 4.0|
| 3| d| 5.0|
| 4| e| 0.0|
| 5| f| 1.0|
+---+--------+-------------+
关于apache-spark - Java 中的 1-of-k 编码 Apache Spark,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44200383/