我有一个 3D 数值数据文件,我以 block 的形式读取该数据(因为以 block 的形式读取比单个索引更快)。例如,假设"file"中有一个 MxNx30 数组,我将创建一个如下所示的 RDD:
def read(ind):
f = customFileOpener(file)
return f['data'][:,:,ind[0]:ind[-1]+1]
indices = [[0,9],[10,19],[20,29]]
rdd = sc.parallelize(indices,3).map(lambda v:read(v))
rdd.count()
因此,3 个分区中的每一个都有一个大小为 MxNx10 的 numpy.ndarray 元素。
现在,我想拆分每个元素,因此在每个分区中,我有 10 个元素,每个元素都是一个 MxN 数组。我尝试使用 flatMap() 来实现此目的,但收到错误“NoneType 对象不可迭代”:
def splitArr(arr):
Nmid = arr.shape[-1]
out = []
for i in range(0,Nmid):
out.append(arr[...,i])
return out
rdd2 = rdd.flatMap(lambda v: splitArr(v))
rdd2.count()
正确的做法是什么?关键点是(a)我需要从文件中读取 block 数据,以及(b)分割数据,使元素大小为MxN(最好保持分区结构)。
最佳答案
据我了解您的描述,类似这样的事情应该可以解决问题:
rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2)))
或者,如果您更喜欢单独的功能:
def splitArr(arr):
for x in np.rollaxis(arr, 2):
yield x
rdd.flatMap(splitArr)
关于python - 在 Pyspark 中拆分 RDD 分区中的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32962112/