apache-spark - 读取自定义 pyspark 转换器

标签 apache-spark pyspark pipeline apache-spark-ml

在搞了很长一段时间之后,在 Spark 2.3 中,我终于能够保存一个纯 python 自定义转换器。但是在重新加载变压器时出现错误。

我检查了保存的内容,并找到了保存在 HDFS 文件中的所有相关变量。如果有人能发现我在这个简单的变压器中缺少什么,那就太好了。

from pyspark.ml import Transformer
from pyspark.ml.param.shared import Param,Params,TypeConverters

class AggregateTransformer(Transformer,DefaultParamsWritable,DefaultParamsReadable):
    aggCols = Param(Params._dummy(), "aggCols", "",TypeConverters.toListString)
    valCols = Param(Params._dummy(), "valCols", "",TypeConverters.toListString)

    def __init__(self,aggCols,valCols):
        super(AggregateTransformer, self).__init__()
        self._setDefault(aggCols=[''])
        self._set(aggCols = aggCols)
        self._setDefault(valCols=[''])
        self._set(valCols = valCols)

    def getAggCols(self):
        return self.getOrDefault(self.aggCols)

    def setAggCols(self, aggCols):
        self._set(aggCols=aggCols)

    def getValCols(self):
        return self.getOrDefault(self.valCols)

    def setValCols(self, valCols):
        self._set(valCols=valCols)

    def _transform(self, dataset):
        aggFuncs = []
        for valCol in self.getValCols():
            aggFuncs.append(F.sum(valCol).alias("sum_"+valCol))
            aggFuncs.append(F.min(valCol).alias("min_"+valCol))
            aggFuncs.append(F.max(valCol).alias("max_"+valCol))
            aggFuncs.append(F.count(valCol).alias("cnt_"+valCol))
            aggFuncs.append(F.avg(valCol).alias("avg_"+valCol))
            aggFuncs.append(F.stddev(valCol).alias("stddev_"+valCol))

        dataset = dataset.groupBy(self.getAggCols()).agg(*aggFuncs)
        return dataset

保存后加载此转换器的实例时出现此错误。
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-172-44e20f7e3842> in <module>()
----> 1 x = agg.load("/tmp/test")

/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(cls, path)
    309     def load(cls, path):
    310         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 311         return cls.read().load(path)
    312 
    313 

/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(self, path)
    482         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
    483         py_type = DefaultParamsReader.__get_class(metadata['class'])
--> 484         instance = py_type()
    485         instance._resetUid(metadata['uid'])
    486         DefaultParamsReader.getAndSetParams(instance, metadata)

TypeError: __init__() missing 2 required positional arguments: 'aggCols' and 'valCols'

最佳答案

想出了答案!

问题是读取器正在初始化一个新的 Transformer 类,但是我的 AggregateTransformer 的 init 函数没有参数的默认值。

因此更改以下代码行解决了问题!

def __init__(self,aggCols=[],valCols=[]):

将这个答案和问题留在这里,因为我很难找到一个可以在任何地方保存和回读的纯 python 转换器的工作示例!它可以帮助寻找这个的人。

关于apache-spark - 读取自定义 pyspark 转换器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52443326/

相关文章:

python-3.x - 如何在 AWS EMR 上设置 PYTHONHASHSEED

apache-spark - 如何分析 pyspark 工作

apache-spark - 获取 java.lang.LinkageError : ClassCastException when use spark sql hivesql on yarn

scala - 如何将 ML 稀疏向量类型的变量转换为 MLlib 稀疏向量类型?

apache-spark - 如何删除超过 X 天/年的 Databricks 数据?

python - PySpark 按条件计算值

apache-spark - Spark : Using null checking in a CASE WHEN expression to protect against type errors

python - Scikit 学习管道类型错误 : zip argument #2 must support iteration

python - 在一个对象中处理标签编码、转换和估计

python - 将 ColumnTransformer 用于管道时出现 AttributeError