python - 将大矩阵与dask相乘

标签 python matrix dask

我正在做一个基本上可以归结为求解矩阵方程的项目

A.dot(x) = d

哪里A是一个到 2000 年尺寸大约为 10 000 000 的矩阵(我想最终在两个方向上增加它)。

A显然不适合内存,所以这必须并行化。我通过解决 A.T.dot(A).dot(x) = A.T.dot(d) 来做到这一点反而。 A.T将具有 2000 x 2000 的尺寸。可以通过除以 A 来计算。和 d成 block A_id_i , 沿行计算 A_i.T.dot(A_i)A_i.T.dot(d_i) ,并对这些求和。非常适合并行化。我已经能够使用多处理模块实现这一点,但是 1) 由于内存使用,很难进一步扩展(在两个维度上增加 A),并且 2)不漂亮(因此不容易维护) .

Dask 似乎是一个非常有前途的库来解决这两个问题,我也做了一些尝试。我的A矩阵计算起来很复杂:它基于大约 15 个不同的数组(大小等于 A 中的行数),其中一些用于迭代算法以评估相关的勒让德函数。当chunks很小(10000行)时,构建task graph的时间会非常长,而且会占用大量的内存(内存的增加与迭代算法的调用是一致的)。当 block 较大时(50000 行),计算前的内存消耗会小很多,但在计算 A.T.dot(A) 时会迅速耗尽。 .我试过 cache.Chest , 但它显着减慢了计算速度。

任务图必须非常大且复杂 - 调用 A._visualize()崩溃。更简单 A矩阵,它可以直接执行此操作(请参阅@MRocklin 的回复)。我有办法简化它吗?

任何有关如何解决此问题的建议都将不胜感激。

作为玩具示例,我尝试了

A = da.ones((2e3, 1e7), chunks = (2e3, 1e3)) 
(A.T.dot(A)).compute()

这也失败了,耗尽了所有内存,只有一个核心处于事件状态。与 chunks = (2e3, 1e5) , 所有核心几乎立即启动,但是 MemoryError在 1 秒内出现(我当前的计算机上有 15 GB)。 chunks = (2e3, 1e4)更有希望,但它最终也消耗了所有内存。

编辑: 我把toy example test打通了,因为尺寸不对,改了其余的尺寸。正如@MRocklin 所说,它确实适用于正确的尺寸。我添加了一个我现在认为与我的问题更相关的问题。

编辑2: 这是我尝试做的一个非常简单的例子。我认为,问题在于定义 A 中的列所涉及的递归。 .

import dask.array as da

N = 1e6
M = 500

x = da.random.random((N, 1), chunks = 5*M)

# my actual 
A_dict = {0:x}
for i in range(1, M):
    A_dict[i] = 2*A_dict[i-1]
A = da.hstack(tuple(A_dict.values()))
A = A.rechunk((M*5, M))
ATA = A.T.dot(A)

这似乎导致了一个非常复杂的任务图,甚至在计算开始之前就占用了大量内存。

我现在通过将递归放在函数中解决了这个问题,使用 numpy数组,或多或少做 A = x.map_blocks(...) .

作为第二个注意事项,一旦我有了 A矩阵任务图,计算A.T.dot(A)直接似乎会出现一些内存问题(内存使用不是很稳定)。因此,我明确地分块计算,并对结果求和。即使有这些变通办法,dask 也会在速度和可读性方面产生很大差异。

最佳答案

你的输出非常非常大。

>>> A.T.dot(A).shape
(10000000, 10000000)

也许您打算用另一个方向的转置来计算这个?

>>> A.dot(A.T).shape
(2000, 2000)

这仍然需要一段时间(这是一个很大的计算量)但它确实完成了。

关于python - 将大矩阵与dask相乘,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35342769/

相关文章:

python - 从 NumPy 或 PyTorch 中的矩阵获取对角线 "stripe"

c++ - 用于矩阵加法的 Cuda 程序

python - 如何为默认的 dask 调度程序指定线程/进程数

python - 删除具有三个元素的元组的列表中的冗余

python - asfreq 和 resample 之间的区别

python - 在 python 中准备一个返回矩阵有什么好处?

python - dask 数据帧中 .join 的结果似乎取决于 dask 数据帧的生成方式

python - 使用 Dask 处理大型、压缩的 csv 文件

python - 字符串无法正确比较

python - pyinstaller: ImportError: 无法导入名称 _elementpath