我在 numpy 中使用一个大的二维矩阵 (dtype=np.bool
) 乘以一维向量 (dtype=np.uint32
) matmult
(即点积)。
np.matmul(matrix, vector, out=vector)
它工作得很好,但我在进行较大的计算时总是出现内存不足的情况。
令我印象深刻的一件事是点积的结果向量 - 偶然 - 我只关心返回的单位 - 例如,在返回的结果向量中 - 比如说,包含的元素整数 1234,只有 4 个重要,对于 36,只有 6 个重要,依此类推...
这有点遥远,但与二进制整数在溢出时会无缝翻转的方式非常相似 - 例如 int8 将按如下方式递增:254, 255, 256, 0, 1, 2... .
我想知道是否有一种方法可以创建一个只存储半字节(如果不支持半字节,则存储一个字节)的数据类型,这样在任何算术运算中都只携带十进制单位?
这对于常规二进制编码来说几乎肯定是不可能的,因为十进制单位存储在 2 的所有幂中。但是,如果有一个用于 numpy 的 BCD 编码数据类型(或一种有效构造数据类型的方法),那么我也许只能存储任何算术运算的 LSB,并且仍然完美地跟踪单元,默默地丢弃每个算术运算中的其他字节;类似于 numpy 的 int 类型溢出的示例。
我知道我可以从 BCD 向后和向前转换为二进制 - 但这没有捕获要点 - 整个计算必须以 BCD 完成才能工作。任何转换都只会需要更多内存。
无论创建什么dtype
来存储matmult
的结果,都必须足够大以存储length(vector)*max(vector)*max(matrix) [row]) - 这通常是一个无符号的 32 位数字(特别是对于我的问题,它是 522659*9*1)...对于 uint16 来说太大了;但在此之后,我立即用 (result_vector % 10) 丢弃大部分结果,这将存储在 8 位无符号数据类型中。
内存浪费相当大(分析显示结果使用 uint32
意味着在我的情况下需要 ~1TB 内存,如果限制为结果可以存储在 ~254gb 中) uint8
)。
那么有没有什么方法可以通过限制输入的类型来丢弃计算结果 - 使用 BCD 或其他方式?测试表明,如果我将输入向量设置为 int8
,计算将毫无怨言地继续进行,但会按上述方式滚动 - 因此它可以使用正确的类型。
但是,我的猜测是我必须从头开始完全实现 bcd 类型及其所有操作才能做到这一点?或者实现我自己的自定义矩阵计算?
我很乐意这样做,但想先检查一下我是否错过了一个技巧!
最后一件事 - 分析显示 scipy.spare 矩阵无法充分利用矩阵中的零,因此使用此技巧不会节省内存。索引的成本超过了节省的成本,并且比常规 numpy 使用更多的内存。
我研究过使用结构化数据类型和 View ,这似乎是我正在寻找的大致内容,但我认为两者都不符合这里的要求。
非常感谢任何想法。
最佳答案
使用 uint8
类型和一些模除法可以减少一点内存:
import numpy as np
matrix = np.random.choice([True,False],size=(10000,10000))
vector = np.random.randint(0,10000,10000).astype('uint32')
def func1(matrix,vector):
z = np.empty(1,dtype='uint8')
v = np.empty(vector.shape[0],dtype='uint8')
for i,row in enumerate(matrix):
z = np.tensordot(row,vector,axes=(-1,-1))
v[i] = z%10
return v
def func2(matrix,vector):
z = np.empty(1,dtype='uint8')
v = np.empty(vector.shape[0],dtype='uint8')
for i,row in enumerate(matrix):
np.matmul(row,vector,out=z)
v[i] = z%10
return v
matmul
在这种情况下工作得更快。时机明智:
%timeit func1(matrix,vector)
668 ms ± 3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit func2(matrix,vector)
418 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
尝试计算每个总和以迭代地处理 uint8
类型的对象是可行的,但是大多数情况下我认为这是不值得的 - 这里的中间对象可能是大整数,但 func2
函数一次仅生成一个。因此,即使一行和向量的 matmul 是一个很大的数字,一次也只能存储在内存中。
numpy
实际上可能会迭代地写入z
,在这种情况下,问题没有实际意义——可能值得查看源代码来仔细检查这一点,如果这是一个问题。
关于python - numpy 中的二进制编码的十进制 dtype,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59512188/