我正在尝试使用 armadillo 进行线性回归,如下面的函数所示:
void compute_weights()
{
printf("transpose\n");
const mat &xt(X.t());
printf("inverse\n");
mat xd;
printf("mul\n");
xd = (xt * X);
printf("inv\n");
xd = xd.i();
printf("mul2\n");
xd = xd * xt;
printf("mul3\n");
W = xd * Y;
}
我已将其拆分,以便我可以看到程序变得如此庞大时发生了什么。矩阵 X 有 64 列和超过 2300 万行。转置并不算太糟糕,但是第一次乘法会导致内存占用完全爆炸。现在,据我了解,如果我乘以 X.t() * X,矩阵乘积的每个元素将是 X 的一列和 X.t() 的一行的点积,结果应该是一个 64x64 矩阵。
当然要花很长时间,但是为什么内存会突然爆到将近30G呢?
然后它似乎卡在那个内存上,然后当它进行第二次乘法时,它太多了,操作系统会因为变得如此之大而杀死它。
有没有一种方法可以在不占用太多内存的情况下计算乘积?那段内存能回收吗?有没有更好的方法来表示这些计算?
最佳答案
除非您使用大型工作站,否则您不可能一次性完成整个乘法运算。正如 hbrerkere 所说,您的初始消耗量约为 22 GB。因此,您要么为此做好准备,要么另辟蹊径。
如果您没有这样的工作站,另一种方法是自己进行乘法运算,然后将其并行化。方法如下:
- 不要将整个矩阵加载到内存中,而是加载其中的一部分。
- 加载大约一百万行 X,并将其存储在某个地方。
- 加载 Y 的一百万列
- 使用
std::transform
使用二元运算符std::multiplies
乘以你加载的部分(这将利用你的处理器的矢量化,并使其快速),并填写你计算的部分结果。 - 加载矩阵的下一部分,然后重复
这不会那么有效,但它会起作用。另一种选择是在将矩阵分解为更小的矩阵后考虑使用 Armadillo ,其乘法将产生子结果。
这两种方法都比完全乘法慢得多,原因有两个:
- 从内存中加载和删除数据的开销
- 矩阵乘法已经是一个复杂度为 O(N^3) 的问题...现在拆分你的乘法是复杂度为 O(N^2),所以它会变成复杂度为 O(N^6)...
祝你好运!
关于c++ - Armadillo:矩阵乘法占用大量内存,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42501116/