python - 如何根据另一个二维数组中给出的索引对二维数组进行切片

标签 python numpy numpy-ndarray array-broadcasting numpy-slicing

我有一个MxN名为 A 的数组其中存储了我想要的数据。我还有一个M x N2数组 B它存储数组索引,和 N2<N 。每行B存储我想从 A 获取该行的元素的索引。例如,以下代码适用于我:

A_reduced = np.zeros((M,N2))
for i in range(M):
    A_reduced[i,:] = A[i,B[i,:]]

是否有任何“矢量化”方法可以从A中提取所需的元素基于B而不是循环遍历每一行?

最佳答案

您可以利用数组索引并使用 reshape :

# set up M=N=4, N2=2
a = np.arange(16).reshape(4,4)
b = np.array([[1,2],[0,1],[2,3],[1,3]])

row_idx = np.repeat(np.arange(b.shape[0]),b.shape[1])
col_idx = b.ravel()

# output:
a[row_idx, col_idx].reshape(b.shape)

输出:

array([[ 1,  2],
       [ 4,  5],
       [10, 11],
       [13, 15]])

更新:另一个类似的解决方案

row_idx = np.repeat(np.arange(b.shape[0]),b.shape[1]).reshape(b.shape)

# output
a[row_idx,b]

关于python - 如何根据另一个二维数组中给出的索引对二维数组进行切片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61400726/

相关文章:

python - 50% 处的 CDF x 值和平均值不显示相同的数字

python - 迭代 python 数组并找到 50 个值的平均值/最小值/最大值

python - PyPI 贝塞尔曲线 0.8.0 : Minimum number of points needed to plot a smooth bezier curve?

python - 如何在 Google Cloud Functions 中查看未捕获异常的堆栈跟踪?

python - 在python中初始化矩阵

python - 如何使用 python 将其所有索引值的总和添加到 numpy ndarray 的每个元素?

python - 对数组进行切片以排除单个元素

python - 在生产中使用带有 zeromq 的 Flask 的合适方法是什么?

python - 迭代定义的 Numpy 数组创建

python - 如何在 Python 中生成日志均匀分布?