python - 为什么 numpy 不会在非连续数组上短路?

标签 python numpy short-circuiting

考虑以下简单测试:

import numpy as np
from timeit import timeit

a = np.random.randint(0,2,1000000,bool)

让我们找到第一个 True

的索引
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332

这相当快,因为​​ numpy 短路。

它也适用于连续的切片,

timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635

但似乎不是在非连续的上。我主要对找到最后一个 True 感兴趣:

timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168

UPDATE: My assumption that the observed slowdown was due to not short circuiting is inaccurate (thanks @Victor Ruiz). Indeed, in the worst-case scenario of an all False array

b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441

we are still an order of magnitude faster than in the non-contiguous case. I'm ready to accept Victor's explanation that the actual culprit is a copy being made (timings of forcing a copy with .copy() are suggestive). Afterwards it doesn't really matter anymore whether short-circuiting happens or not.

但其他步长 != 1 会产生类似的行为。

timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296

问题:为什么 numpy 在最后两个例子中没有短路 UPDATE 没有复制

而且,更重要的是:是否有解决方法,即强制 numpy 短路 UPDATE 而不制作副本在非连续数组上?

最佳答案

问题与使用strides时数组的内存对齐有关。 a[1:-1]a[::-1] 被认为在内存中对齐,但 a[::2] 不要:

a = np.random.randint(0,2,1000000,bool)

print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False

这解释了为什么 np.argmaxa[::2] 上很慢(来自 ndarrays 上的文档):

Several algorithms in NumPy work on arbitrarily strided arrays. However, some algorithms require single-segment arrays. When an irregularly strided array is passed in to such algorithms, a copy is automatically made.

np.argmax(a[::2]) 正在制作数组的副本。因此,如果您执行 timeit(lambda: np.argmax(a[::2]), number=5000),您将计时数组 a

执行这个并比较这两个计时调用的结果:

print(timeit(lambda: np.argmax(a[::2]), number=5000))

b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))

编辑: 深入研究numpy的C语言源代码,我发现了argmax函数的下划线实现,PyArray_ArgMax在某个时候调用 PyArray_ContiguousFromAny确保给定的输入数组在内存中对齐(C 风格)

然后,如果数组的 dtype 是 bool,它委托(delegate)给 BOOL_argmax功能。 查看其代码,似乎始终应用了短路。

总结

  • 为了避免被np.argmax复制,确保输入数组在内存中是连续的
  • 当数据类型为 bool 值时,始终应用短路。

关于python - 为什么 numpy 不会在非连续数组上短路?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57346182/

相关文章:

python - 使用 Xarray 迭代创建 DataArray 的最直接/紧凑的方法

python - 当Python中簇数为2时,我的图没有显示结果

Python:通过迭代加速矩阵坐标映射

Python bool 运算符的优先级规则

Oracle CASE 短路不能按组工作

用于 __str__ 和方法解析顺序的 Python Mixin

Python:将查询字符串分解为关联数组不起作用

python - 如果一个元素小于或大于某个值,如何删除 2D numpy 数组中的列

JavaScript 使用逗号进行短路变量赋值

Python Mock_requests : Can I use wildcards in the url parameter of the Mocker? 与pytest一起使用时如何实现url的模式匹配?