apache-spark - pyspark:CrossValidator 不起作用

标签 apache-spark pyspark apache-spark-mllib apache-spark-ml

我正在尝试调整 ALS 的参数,但始终选择第一个参数作为最佳选项

from pyspark.sql import SQLContext
from pyspark import SparkConf, SparkContext
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator
from math import sqrt

from operator import add

conf = (SparkConf()
         .setMaster("local[4]")
         .setAppName("Myapp")
         .set("spark.executor.memory", "2g"))
sc = SparkContext(conf = conf)

sqlContext = SQLContext(sc)
def computeRmse(data):
    return (sqrt(data.map(lambda x: (x[2] - x[3]) ** 2).reduce(add) / float(data.count())))

dfRatings = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
                                 ["user", "item", "rating"])

lr1 = ALS()
grid1 = ParamGridBuilder().addGrid(lr1.regParam, [1.0,0.005,2.0]).build()
evaluator1 = RegressionEvaluator(predictionCol=lr1.getPredictionCol(),labelCol=lr1.getRatingCol(), metricName='rmse')
cv1 = CrossValidator(estimator=lr1, estimatorParamMaps=grid1, evaluator=evaluator1, numFolds=2)
cvModel1 = cv1.fit(dfRatings)
a=cvModel1.transform(dfRatings)
print ('rmse with cross validation: {}'.format(computeRmse(a)))

for reg_param in (1.0,0.005,2.0):
    lr = ALS(regParam=reg_param)
    model = lr.fit(dfRatings)
    print ('reg_param: {}, rmse: {}'.format(reg_param,computeRmse(model.transform(dfRatings))))

输出:
交叉验证的均方根误差:1.1820489116858794
reg_param: 1.0, rmse: 1.1820489116858794
reg_param: 0.005, rmse: 0.001573816765686575
reg_param:2.0,rmse:2.1056964491942787

有什么帮助吗?

提前致谢,

最佳答案

抛开其他问题不谈,您根本就没有使用足够的数据来执行有意义的交叉验证和评估。正如我在 Spark ALS predictAll returns empty 中所解释和说明的那样当训练集中缺少用户或项目时,ALS 无法提供预测。

这意味着交叉验证期间的每个分割都将具有未定义的预测,并且总体评估将是未定义的。因此,CrossValidator 将返回第一个可能的模型,因为从它的角度来看,您训练的所有模型都同样糟糕。

关于apache-spark - pyspark:CrossValidator 不起作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38417431/

相关文章:

python - 尝试使用 pyspark 加载保存的 Spark 模型时出现 "empty collection"错误

scala - 如何在 Scala Spark 中将稀疏向量转换为密集向量?

apache-spark - Spark Master 填充临时目录

amazon-web-services - Spark : Exception in thread "dag-scheduler-event-loop" java. lang.OutOfMemoryError: Java 堆空间

python - 如何在 PySpark 中使用窗口函数?

apache-spark - Firehose JSON -> S3 Parquet -> ETL Spark,错误 : Unable to infer schema for Parquet

apache-spark - Spark /pySpark : Best way to read small binary data files

r - 是什么导致 R 在处理大型数据集时崩溃?

python - 如何在spark中设置驱动程序的python版本?

scala - 将 QuantileDiscretizer 应用于 DataFrame 中的所有列