python - 具有不同位置的 NumPy 索引

标签 python arrays performance numpy vectorization

我有一个形状为 (A, B, C) 的数组 input_data 和一个形状为 (B,) 的数组 ind。我想循环 B 轴并获取元素 C[B[i]] 和 C[B[i]+1] 的总和。所需输出的形状为 (A, B)。我有以下有效的代码,但我觉得由于 B 轴基于索引的循环而效率低下。有没有更有效的方法?

import numpy as np

input_data = np.random.rand(2, 6, 10)
ind = [ 2, 3, 5, 6, 5, 4 ]

out = np.zeros( ( input_data.shape[0], input_data.shape[1] ) )

for i in range( len(ind) ):
    d = input_data[:, i, ind[i]:ind[i]+2]
    out[:, i] = np.sum(d, axis = 1)

根据 Divakar 的回答进行编辑:

import timeit
import numpy as np

N = 1000

input_data = np.random.rand(10, N, 5000)
ind = ( 4999 * np.random.rand(N) ).astype(np.int)

def test_1(): # Old loop-based method
    out = np.zeros( ( input_data.shape[0], input_data.shape[1] ) )

    for i in range( len(ind) ):
        d = input_data[:, i, ind[i]:ind[i]+2]
        out[:, i] = np.sum(d, axis = 1)
    return out

def test_2(): 
    extent = 2 # Comes from 2 in "ind[i]:ind[i]+2"

    m,n,r = input_data.shape
    idx = (np.arange(n)*r + ind)[:,None] + np.arange(extent)
    out1 = input_data.reshape(m,-1)[:,idx].reshape(m,n,-1).sum(2)
    return out1

print timeit.timeit(stmt = test_1, number = 1000)
print timeit.timeit(stmt = test_2, number = 1000)

print np.all( test_1() == test_2(), keepdims = True )

>> 7.70429363482
>> 0.392034666757
>> [[ True]]

最佳答案

这是使用 linear indexing 的矢量化方法在 broadcasting 的帮助下。我们合并输入数组的最后两个轴,计算与最后两个轴对应的线性索引,执行切片并 reshape 回 3D 形状。最后,我们沿最后一个轴求和以获得所需的输出。实现看起来像这样 -

extent = 2 # Comes from 2 in "ind[i]:ind[i]+2"

m,n,r = input_data.shape
idx = (np.arange(n)*r + ind)[:,None] + np.arange(extent)
out1 = input_data.reshape(m,-1)[:,idx].reshape(m,n,-1).sum(2)

如果范围始终为2,如问题中所述 - “...元素 C[B[i]] 和的总和C[B[i]+1]",那么你可以简单地这样做 -

m,n,r = input_data.shape
ind_arr = np.array(ind)
axis1_r = np.arange(n)
out2 = input_data[:,axis1_r,ind_arr] + input_data[:,axis1_r,ind_arr+1]

关于python - 具有不同位置的 NumPy 索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35265914/

相关文章:

javascript - 只更新可见的 DOM 元素?

python - PostgreSQL 查询速度慢,问题是什么?

python - 仅当满足每行元素的条件时,才计算 2D 数组特定列的均值和方差

python - 从 JSON 字符串中转义数据的任何巧妙方法?

c# - 使用 Random 类随机化二维数组

javascript - JS 中检查数组是否存在的正确方法是什么?

c# - System.Threading.Timer卡住计算机

python - 使用 `exec` 调用时如何更新局部变量?

python - 在 4store 中添加三元组

c# - 在 C# 中将对象数组转换为另一种类型的简单方法