python - 将 pyspark 数据帧中一列的字符串列表转换为用于 one-hot 编码的字符串

标签 python dataframe apache-spark pyspark

我的问题与我之前的问题相关: transform columns values to columns in pyspark dataframe

我创建了一个表“my_df”(pyspark 中的数据框):

+----+--------+---------------------------------+
|id  |payment        |shop                      |
+----+--------+---------------------------------+
|dapd|[credit, cash] |[retail, on-line]         |
|wrfr|[cash, debit]  |[supermarket, brand store]|
+----+--------+---------------------------------+

现在,我需要对表进行聚类,以便可以找到“id”的相似性。 我一开始尝试k-means。因此,我需要通过 one-hot 编码将分类值转换为数值。 我指的是How to handle categorical features with spark-ml?

我的代码:

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoderEstimator

inputs, my_indx_list = [], []
for a_col in my_df.columns: 
  my_indx = StringIndexer(inputCol = a_col, outputCol = a_col + "_index")
  inputs.append(my_indx.getOutputCol())
  my_indx_list.append(my_indx)

  encoder = OneHotEncoderEstimator(inputCols=inputs, outputCols=[x + "_vector" for x in inputs])
  a_pipeline = Pipeline(stages = my_indx_list + [encoder])
  pipeline.fit(my_df).transform(my_df).show() # error here !

但是,我遇到了错误:

A column must be either string type or numeric type, but got ArrayType(StringType,true)

那么,我该如何解决这个问题呢?

我的想法:对每列的列表值进行排序,并将列表中的每个字符串连接成每列的长字符串。

但是,对于每一列,这些值是一些调查问题的答案,并且每个答案具有相同的权重。 我不知道如何解决?

谢谢

更新

根据建议的解决方案,它可以工作,但速度非常慢。 在300GB内存、32核的集群上,耗时约3.5小时。

我的代码:

   from pyspark.ml.feature import CountVectorizer
   tmp_df = original_df # 3.5 million rows and 300 columns

   for a_col in original_df.columns: 
        a_vec = CountVectorizer(inputCol = a_col, outputCol = a_col + "_index", binary=True)
        tmp_df = a_vec.fit(tmp_df).transform(tmp_df)

  tmp_df.show()

“original_df”有 350 万行和 300 列。

如何才能加快速度?

谢谢

最佳答案

@jxc 建议在您的案例中巧妙地使用 CountVectorizer 进行 one-hot 编码,这通常用于计算自然语言处理中的标记。

使用CountVectorizer可以让您省去用OneHotEncoderEstimator处理explodecollect_set的麻烦;如果您尝试使用 udf 来实现它,情况会更糟。

鉴于此数据框,

df = spark.createDataFrame([
                            {'id': 'dapd', 'payment': ['credit', 'cash'], 'shop': ['retail', 'on-line']},
                            {'id': 'wrfr', 'payment': ['cash', 'debit'], 'shop': ['supermarket', 'brand store']}
                           ])
df.show()

+----+--------------+--------------------+
|  id|       payment|                shop|
+----+--------------+--------------------+
|dapd|[credit, cash]|   [retail, on-line]|
|wrfr| [cash, debit]|[supermarket, bra...|
+----+--------------+--------------------+

您可以通过将字符串数组视为自然语言处理中的标记来进行单热编码。请注意使用 binary=True 强制其仅返回 0 或 1。

from pyspark.ml.feature import CountVectorizer

payment_cv = CountVectorizer(inputCol="payment", outputCol="paymentEnc", binary=True)
first_res_df = payment_cv.fit(df).transform(df)

shop_cv = CountVectorizer(inputCol="shop", outputCol="shopEnc", binary=True)
final_res_df = shop_cv.fit(first_res_df).transform(first_res_df)

final_res_df.show()

+----+--------------+--------------------+-------------------+-------------------+
|  id|       payment|                shop|         paymentEnc|            shopEnc|
+----+--------------+--------------------+-------------------+-------------------+
|dapd|[credit, cash]|   [retail, on-line]|(3,[0,2],[1.0,1.0])|(4,[0,3],[1.0,1.0])|
|wrfr| [cash, debit]|[supermarket, bra...|(3,[0,1],[1.0,1.0])|(4,[1,2],[1.0,1.0])|
+----+--------------+--------------------+-------------------+-------------------+

关于python - 将 pyspark 数据帧中一列的字符串列表转换为用于 one-hot 编码的字符串,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58985931/

相关文章:

Python:每个固定时间窗口从 pandas 数据帧中提取行

java - (Spark skewed join) 如何在没有内存问题的情况下连接两个具有高度重复键的大型 Spark RDD?

apache-spark - 如何在Spark结构化流中将两个流df写入MySQL中的两个不同表中?

python - 从先前终止的连接到数据库提交 SQLite 中的现有日志文件

python 对文件的扩展名/名称进行排序

python - 旋转数据框以自动创建列

python - 从列范围返回新数据框( Pandas )

hadoop - Spark/Hadoop/Yarn集群通信需要外部ip?

python - 使用 Python 将 Excel DATE(而非日期时间)数据插入 SQL Server

python - 如何在centos上为一个用户安装python?