python - 是否可以在 numpy 数组切片运算符中组合逻辑条件和限制条件

标签 python arrays numpy slice

我有以下代码可以完全满足我的要求,但速度太慢,因为它涉及不必要的具体化步骤:

### init
a = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

### condition 1) element 0 has to be larger than 1
### condition 2) limit the output to 2 elements
b = a[a[:,0] > 1][:2]

问题 是当我有一个大数组时这非常慢(假设我只想切掉条件 2 的一小块)。这很容易完成,但我还没有找到将其放入单行代码的方法。

因此,是否有一种巧妙的方法可以在 one-liner 中高效地执行此操作?像这样:

b = a[a[:,0] > 1 and :2]

谢谢!

最佳答案

我想不出直接使用 numpy 的更快解决方案,但使用 numba 可能会做得更好:

from numba import autojit

def filtfunc(a):
    idx = []
    for ii in range(a.shape[0]):
        if (a[ii, 0] > 1):
            idx.append(ii)
            if (len(idx) == 2):
                break
    return a[idx]

jit_filter = autojit(filtfunc)

作为引用,以下是另外两个建议的解决方案:

def marco_filter(a):
    return a[a[:,0] > 1][:2]

def rico_filter(a):
    mask = a[:, 0] > 1
    where = np.where(mask)[0][:2]
    return a[where]

一些时间:

%%timeit a = np.random.random_integers(1, 12, (1000,1000))
marco_filter(a)
# 100 loops, best of 3: 11.6 ms per loop

%%timeit a = np.random.random_integers(1, 12, (1000,1000))
rico_filter(a)
# 10000 loops, best of 3: 44.8 µs per loop

%%timeit a = np.random.random_integers(1, 12, (1000,1000))
jit_filter(a)
# 10000 loops, best of 3: 30.7 µs per loop

关于python - 是否可以在 numpy 数组切片运算符中组合逻辑条件和限制条件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20693428/

相关文章:

python - 索引错误: shape mismatch: indexing arrays could not be broadcast together with shapes

python - 在循环条件中评估表达式

python - 迭代字典中的列表

javascript - 在运行时在javascript数组中添加对象

python - 为什么 NumPy 中的轴参数会改变?

python - 返回列表中最后一个非零元素的索引

python - 使用 Amazon s3 时 Django ImageField url 变慢

python - 如何在打印语句后取消换行符?

objective-c - 简单数组的问题 - Cocoa

python - 将文件读入由段落Python分隔的数组