python - 在 Pyspark 中拆分 RDD 分区中的数组

标签 python numpy apache-spark pyspark

我有一个 3D 数值数据文件,我以 block 的形式读取该数据(因为以 block 的形式读取比单个索引更快)。例如,假设"file"中有一个 MxNx30 数组,我将创建一个如下所示的 RDD:

def read(ind):
    f = customFileOpener(file)
    return f['data'][:,:,ind[0]:ind[-1]+1]

indices = [[0,9],[10,19],[20,29]]
rdd = sc.parallelize(indices,3).map(lambda v:read(v))
rdd.count()

因此,3 个分区中的每一个都有一个大小为 MxNx10 的 numpy.ndarray 元素。

现在,我想拆分每个元素,因此在每个分区中,我有 10 个元素,每个元素都是一个 MxN 数组。我尝试使用 flatMap() 来实现此目的,但收到错误“NoneType 对象不可迭代”:

def splitArr(arr):
    Nmid = arr.shape[-1]
    out = []
    for i in range(0,Nmid):
         out.append(arr[...,i])
    return out

rdd2 = rdd.flatMap(lambda v: splitArr(v))
rdd2.count()

正确的做法是什么?关键点是(a)我需要从文件中读取 block 数据,以及(b)分割数据,使元素大小为MxN(最好保持分区结构)。

最佳答案

据我了解您的描述,类似这样的事情应该可以解决问题:

rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2)))

或者,如果您更喜欢单独的功能:

def splitArr(arr):
    for x in np.rollaxis(arr, 2):
        yield x

rdd.flatMap(splitArr)

关于python - 在 Pyspark 中拆分 RDD 分区中的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32962112/

相关文章:

python - 访问另一个模块的 __all__

python - Numpy.where 与值列表一起使用

python - F2PY 无法看到模块范围变量

python - 非线性函数的 numpy 爱因斯坦求和

apache-spark - Kerberos Cloudera Hadoop 的 livy curl 请求错误

python - 将 CSV 表转换为 Redis 数据结构

python - 如何将 geckodriver 放入 PATH?

scala - 使用一种热编码和向量汇编器与向量索引器来处理分类特征

python以mm/dd/yyyy格式获取文件的时间戳

python - 没有 SQLContext 的 pyspark 中的 clearCache