math - 16 位定点矩阵乘法

标签 math matrix matrix-multiplication fixed-point

我需要在神经网络中的不同层之间执行矩阵乘法。即:W0, W1, W2, ... Wn 是神经网络的权重,输入是data。结果矩阵是:

Out1 = data * W0
Out2 = Out1 * W1
Out3 = Out2 * W2
.
.
.
OutN = Out(N-1) * Wn

我知道权重矩阵中的绝对最大值,而且我也知道输入数据范围值是从 0 到 1(输入已标准化)。矩阵乘法是16位定点乘法。权重被缩放到最佳格式点。例如:如果W0中的绝对最大值是2.5,我知道整数部分的最小位数是2,小数部分的位数将是14。因为数据输入在范围 [0,1] 我也知道整数和小数位是 1.15。

我的问题是:如何知道结果矩阵中整数部分的最小位数以避免溢出?有没有办法研究和推断矩阵乘法的最大值?我知道矩阵的行列式和范数,但是,我认为问题出在矩阵行列中的连续负值或正值。例如,如果我有这个行向量和这个列向量,并且结果是 8 位定点:

A = [1, 2, 3, 4, 5, 6, -7, -8]
B = [1, 2, 3, 4, 5, 6, 7, 8]
A * B = (1*1) + (2*2) + (3*3) + (4*4) + (5*5) + (6*6) + (7*-7) + (8*8) = 90 - 49 + -68

当累加器小于64时,虽然最终结果在[-64,63]之间,但会发生溢出。

另一个例子:如果我有这个行向量和这个列向量,并且结果是 8 位定点:

A = [1, -2, 3, -4, 5, -6, 7, -8]
B = [1, 2, 3, 4, 5, 6, 7, 8]
A * B = (1*1) - (2*2) + (3*3) - (4*4) + (5*5) - (6*6) + (7*7) - (8*8) = -36

任意时刻的累加器都超出了8位的最大范围。

总结:我正在寻找一种分析权重矩阵的方法,以避免总和累加器溢出。我进行矩阵乘法的方法是(仅作为矩阵 A 和 B 已升级为 1.15 格式的示例):

A1 --> 1.15 bits
B1 --> 1.15 bits
A2 --> 1.15 bits
B2 --> 1.15 bits
mult_1 = (A1 * B1) >> 2^15; // Right shift to alineate the operands
mult_2 = (A2 * B2) >> 2^15; // Right shift to alineate the operands
sum_acc = mult_1 + mult_2;  // Sum accumulator

最佳答案

让我们以 %3.13 定点格式的 n=100 维点积(它是任何矩阵乘法或卷积的一部分)为例。

  1. 整数位

    %4.13 中的最大值略低于 2^4,因此我们考虑一下它是:15.999999

    现在,n 维点积具有 n 次乘法和 n-1 次加法。

    15.999999*15.999999 + 15.999999*15.999999 + .... + 15.999999*15.999999
    

    每次乘法都会对整数位求和

    15.999999*15.999999 = 255.999999 -> ceil(log2(255)) = 8 = 2*(4)-> %8.13
    

    现在这个值被添加了 99 次,所以它与:

    255.999999*99 = 25343.999999 -> ceil(log2(25343)) = 15 = ceil(8+log2(99)) -> %15.13
    

    因此,如果 n 是维度数,而 i 是结果所需的整数位数:

    i' = ceil((i*2)+log2(n-1)) 
    

    整数位...所以:

    %1.? -> 99*( 1.999999^2) =   395.99 -> % 9.?
    %2.? -> 99*( 3.999999^2) =  1583.99 -> %11.?
    %3.? -> 99*( 7.999999^2) =  6335.99 -> %13.?
    %4.? -> 99*(15.999999^2) = 25343.99 -> %15.?
    
    i(1) = ceil((1*2)+log2(99)) = ceil(2+6.626) = 9
    i(2) = ceil((2*2)+log2(99)) = ceil(4+6.626) = 11
    i(3) = ceil((3*2)+log2(99)) = ceil(6+6.626) = 13
    i(4) = ceil((4*2)+log2(99)) = ceil(8+6.626) = 15
    
  2. 小数位

    好吧,让我们看看乘法会发生什么:

    0.1b^2 = 0.01b        -> %?.1 -> %?.2
    0.01b^2 = 0.0001b     -> %?.2 -> %?.4
    0.001b^2 = 0.000001b  -> %?.3 -> %?.6
    

    so f' = 2*f 其中 f 是小数位数。添加并不改变位宽:

    0.1b*2 = 1.0b         -> %?.1 -> %?.1
    0.01b*2 = 0.1b        -> %?.2 -> %?.2
    0.001b*2 = 0.01b      -> %?.3 -> %?.3
    

    因为结果不会小于操作数。因此,当将小数部分应用于点积时,我们将得到:

    i' = ceil((i*2)+log2(n-1)) 
    f' = 2*f 
    

关于math - 16 位定点矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65048801/

相关文章:

Python-Numpy 矩阵乘法

java - 确定正方形和矩形之间关系的算法

matlab - 用 block 状或聚合的方法随机替换矩阵中的元素

python - 理解 Python 中的 einsum

python - 在 python 中,我们如何找到两个矩阵之间的相关系数?

matlab - 从矩阵列中减去相应的向量值

matlab - 在 MATLAB 中将两个非常大的稀疏矩阵相乘时出现内存不足错误

javascript - 在 Javascript 中使用带有计时器的随机数生成器每次都会给出相同的数字

c++ - C++ 中的积分(数学)

javascript - 在游戏中跳跃