python - Spark 数据帧随机拆分

标签 python apache-spark pyspark

我有一个 spark 数据框,我想按 0.60、0.20、0.20 的比例分为训练、验证和测试。

我使用了下面的代码:

def data_split(x):
    global data_map_var
    d_map = data_map_var.value
    data_row = x.asDict()
    import random
    rand = random.uniform(0.0,1.0)
    ret_list = ()
    if rand <= 0.6:
        ret_list = (data_row['TRANS'] , d_map[data_row['ITEM']] , data_row['Ratings'] , 'train')
    elif rand <=0.8:
        ret_list = (data_row['TRANS'] , d_map[data_row['ITEM']] , data_row['Ratings'] , 'test')
    else:
        ret_list = (data_row['TRANS'] , d_map[data_row['ITEM']] , data_row['Ratings'] , 'validation')
    return ret_list
​
​
split_sdf = ratings_sdf.map(data_split)
train_sdf = split_sdf.filter(lambda x : x[-1] == 'train').map(lambda x :(x[0],x[1],x[2]))
test_sdf = split_sdf.filter(lambda x : x[-1] == 'test').map(lambda x :(x[0],x[1],x[2]))
validation_sdf = split_sdf.filter(lambda x : x[-1] == 'validation').map(lambda x :(x[0],x[1],x[2]))
​
print "Total Records in Original Ratings RDD is {}".format(split_sdf.count())
​
print "Total Records in training data RDD is {}".format(train_sdf.count())
​
print "Total Records in validation data RDD is {}".format(validation_sdf.count())
​
print "Total Records in test data RDD is {}".format(test_sdf.count())
​
​
#help(ratings_sdf)
Total Records in Original Ratings RDD is 300001
Total Records in training data RDD is 180321
Total Records in validation data RDD is 59763
Total Records in test data RDD is 59837

我的原始数据框是 ratings_sdf,我用它来传递进行拆分的映射器函数。

如果您检查训练总和,验证和测试总和不计入拆分(原始评级)计数。这些数字在每次运行代码时都会发生变化。

剩余的记录去了哪里,为什么总和不相等?

最佳答案

TL;DR 如果你想拆分 DataFrame 使用 randomSplit method :

ratings_sdf.randomSplit([0.6, 0.2, 0.2])

您的代码在多个层面上都是错误的,但有两个基本问题导致它无法修复:

  • Spark 转换可以计算任意次数,您使用的函数应该是引用透明的并且没有副作用。您的代码多次评估 split_sdf 并且您使用有状态 RNG data_split 因此每次结果都不同。

    这会导致您描述的行为,其中每个 child 看到父 RDD 的不同状态。

  • 您没有正确初始化 RNG,因此您获得的随机值不是独立的。

关于python - Spark 数据帧随机拆分,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40293970/

相关文章:

amazon-web-services - 创建 SparkUI 历史服务器的 CF 模板失败

python - 从字典中提取字符串

Python 和 PyQt。剪贴板内容到列表小部件

python - Golang单元测试python函数

python - PySpark 新列,从整数列表中选择值

scala - scala dataframe 中的collect_list 将收集固定列号间隔内的行

java - 避免在 JavaPairRDD Apache Spark 中进行 Group By

Python - 在动态表中添加 mysql 条目

python - 使用 ASCII 拉取按字母顺序排序的最大长度子字符串

xml - Scala:将 xml 数据帧转换为 csv 文件