python - 点积向量化 - NumPy

标签 python arrays numpy vectorization

我有一个函数,我想将剩余的循环向量化。我相信这是正确的,并且我对性能感到满意,但只是想了解更多有关矢量化代码的信息。其功能是:

def f(x, A, c):
    # A : d*d numpy array
    # c : length d numpy array
    # x : N x d or length d numpy array
    x = np.atleast_2d(x)
    b = np.zeros(x.shape[0], dtype=np.bool)
    for row in range(x.shape[0]):
        xmc = x[row, :] - c
        b[row] = xmc.dot(A).dot(xmc) <= 1
    return b

是否可以对函数进行向量化并删除剩余的循环,同时保持其相当简单?当循环中的独立计算无法很好地矢量化时,是否有任何指导原则? N 和 d 的典型值分别为 10000 和 4。谢谢。

最佳答案

你可以像这样矢量化 -

xc = x-c
b_out = ((xc.dot(A))*xc).sum(1) <= 1

您还可以使用np.einsum -

xc = x-c
b_out = np.einsum('ij,jk,ik->i',xc,A,xc) <= 1

运行时测试 -

定义函数:

def org_app(x, A, c):
    x = np.atleast_2d(x)
    b = np.zeros(x.shape[0], dtype=np.bool)
    for row in range(x.shape[0]):
        xmc = x[row, :] - c
        b[row] = xmc.dot(A).dot(xmc) <= 1
    return b

def vectorized_app1(x,A,c):    
    xc = x-c
    return ((xc.dot(A))*xc).sum(1) <= 1

def vectorized_app2(x,A,c):    
    xc = x-c
    return np.einsum('ij,jk,ik->i',xc,A,xc) <= 1

时间安排:

In [266]: N = 20
     ...: d = 20
     ...: A = np.random.rand(d,d)
     ...: c = np.random.rand(d)
     ...: x = np.random.rand(N,d)
     ...: 

In [267]: %timeit org_app(x,A,c)
1000 loops, best of 3: 274 µs per loop

In [268]: %timeit vectorized_app1(x,A,c)
10000 loops, best of 3: 46 µs per loop

In [269]: %timeit vectorized_app2(x,A,c)
10000 loops, best of 3: 63.7 µs per loop

In [270]: N = 100
     ...: d = 100
     ...: A = np.random.rand(d,d)
     ...: c = np.random.rand(d)
     ...: x = np.random.rand(N,d)
     ...: 

In [271]: %timeit org_app(x,A,c)
100 loops, best of 3: 2.74 ms per loop

In [272]: %timeit vectorized_app1(x,A,c)
1000 loops, best of 3: 1.46 ms per loop

In [273]: %timeit vectorized_app2(x,A,c)
100 loops, best of 3: 4.72 ms per loop

关于python - 点积向量化 - NumPy,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33045481/

相关文章:

python - 在 python 中制作 keras 模型的深拷贝

python - 从解析的 XML 文件中配对子元素和嵌套子元素的值

python - 从另一个字典 python 列表更新字典列表

c++ - 具有自定义索引的数组

c - C 中比较字符串的问题

ruby - 如何在 Ruby 中的哈希中初始化数组

Python numpy, reshape /转换数组避免遍历行

python - 加入列并在行中 reshape

python - 二维 numpy 数组中行或列最常见的元素

python - 如何在 tabula-py 中设置页面范围?