python - numpy 获取行和列索引二维数组的所有组合

标签 python numpy

我有一个 2D numpy 数组,如下所示:

import numpy as np
foo = np.array([[(i+1)*(j+1) for i in range(10)] for j in range(5)])

    #array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
    #       [ 2,  4,  6,  8, 10, 12, 14, 16, 18, 20],
    #       [ 3,  6,  9, 12, 15, 18, 21, 24, 27, 30],
    #       [ 4,  8, 12, 16, 20, 24, 28, 32, 36, 40],
    #       [ 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]])

我使用 np.nonzero 创建了一些过滤条件:

csum = np.sum(foo,axis=0)
#array([ 15,  30,  45,  60,  75,  90, 105, 120, 135, 150])
rsum = np.sum(foo,axis=1)
#array([ 55, 110, 165, 220, 275])
cfilter = np.nonzero(csum > 80)
#(array([5, 6, 7, 8, 9]),)
rfilter = np.nonzero(rsum < 165)
#(array([0, 1]),)

现在是否有一些优雅的 numpy 切片方法来获取 foo[r,c] 的所有组合,对于 rfilter 中的 r 和 cfilter 中的 c?即我想获得以下输出:

array([[ 6,  7,  8,  9, 10],
       [12, 14, 16, 18, 20]])

注意:我知道通过基本切片选择从数组中获取 block 很容易,但在更高级的用例中,cfilter 和 rfilter 中的索引不一定彼此相邻。

非常感谢!

最佳答案

要索引叉积,使用np.ix_:

foo[np.ix_(*(rfilter + cfilter))]

您可以直接使用 bool 索引(即不使用 np.nonzero):

foo[np.ix_(np.sum(foo, axis=1) < 165, np.sum(foo, axis=0) > 80)]

请注意,所有 np.ix_ 所做的就是适本地添加轴以提供可以一起广播的索引数组:

>>> np.ix_(*(rfilter + cfilter))
(array([[0],
       [1]]), array([[5, 6, 7, 8, 9]]))

关于python - numpy 获取行和列索引二维数组的所有组合,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/13533719/

相关文章:

python - Django get_context_data() 变量在某处被覆盖

python - 使用 Python 将所有压缩文件提取到同一目录

python - 将 Numpy 存储为 pickled Pandas、Pickled Numpy 或 HDF5

python - 如何将列表中的值插入现有列

python - RollingGroupby 上的 Pandas 聚合方法

Python:每分钟只需要请求 20 次

python - 用另一列中的元素替换数据框中的列 - python

python - 如何保存memcache值直到配额被补充?

python - 将 numpy.void 转换为 numpy.ndarray

python - 替换 Numpy 图像中的像素值