python - 迭代多维 numpy 数组中的向量

标签 python numpy

我有一个 3xNxM 的 numpy 数组 a,我想迭代最后两个轴:a[:,x,y]。不优雅的方法是:

import numpy as np
a = np.arange(60).reshape((3,4,5))
M = np. array([[1,0,0],
               [0,0,0],
               [0,0,-1]])

for x in arange(a.shape[1]):
    for y in arange(a.shape[2]):
        a[:,x,y] = M.dot(a[:,x,y])

这可以用 nditer 来完成吗?这样做的目的是对每个条目执行矩阵乘法,例如a[:,x,y] = M[:,:,x,y].dot(a[:,x,y])。另一种 MATLAB 风格的方法是将 a reshape 为 (3,N*M) 并将 M reshape 为 (3,3*N*M) 并取点积,但这往往会占用大量内存。

最佳答案

虽然随意使用形状可能会使您想要完成的事情更加清晰,但无需考虑太多即可处理此类问题的最简单方法是使用 np.einsum。 :

In [5]: np.einsum('ij, jkl', M, a)
Out[5]: 
array([[[  0,   1,   2,   3,   4],
        [  5,   6,   7,   8,   9],
        [ 10,  11,  12,  13,  14],
        [ 15,  16,  17,  18,  19]],

       [[  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0]],

       [[-40, -41, -42, -43, -44],
        [-45, -46, -47, -48, -49],
        [-50, -51, -52, -53, -54],
        [-55, -56, -57, -58, -59]]])

此外,它通常还附带绩效奖金:

In [17]: a = np.random.randint(256, size=(3, 1000, 2000))

In [18]: %timeit np.dot(M, a.swapaxes(0,1))
10 loops, best of 3: 116 ms per loop

In [19]: %timeit np.einsum('ij, jkl', M, a)
10 loops, best of 3: 60.7 ms per loop

编辑 einsum 是非常强大的巫术。您还可以按照以下评论中的 OP 要求进行操作:

>>> a = np.arange(60).reshape((3,4,5))
>>> M = np.array([[1,0,0], [0,0,0], [0,0,-1]])
>>> M = M.reshape((3,3,1,1)).repeat(4,axis=2).repeat(5,axis=3)
>>> np.einsum('ijkl,jkl->ikl', M, b)
array([[[  0,   1,   2,   3,   4],
        [  5,   6,   7,   8,   9],
        [ 10,  11,  12,  13,  14],
        [ 15,  16,  17,  18,  19]],

       [[  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0]],

       [[-40, -41, -42, -43, -44],
        [-45, -46, -47, -48, -49],
        [-50, -51, -52, -53, -54],
        [-55, -56, -57, -58, -59]]])

关于python - 迭代多维 numpy 数组中的向量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/15015471/

相关文章:

python - 省时的宽到长转换 Pandas

python-2.7 - conda 更新 scikit-learn(还有 scipy 和 numpy)

python - 从 CSV 合并 numpy ndarray

python - 在 Numpy 数组的列中进行条件替换

python - 比较 Pandas 中数据帧的标题

python - Pyside:QLineEdit 接受多个输入

python - numpy 数据类型的最大允许值

python - np.isnan() == False,但 np.isnan() 不是 False

python - GeoAlchemy ST_DWithin 实现

python - Pandas 数据框的动态合并