假设我有一个 4-D numpy 数组(例如:np.rand((x,y,z,t))
),其维度对应于 X、Y、Z,和时间。
对于每个 X 和 Y 点,在每个时间步,我想找到 Z 中数据大于某个阈值 n 的最大索引。
所以我的最终结果应该是一个 X-by-Y-by-t 数组。 Z 列中没有大于阈值的值的实例应以 0 表示。
我可以逐个元素循环并构造一个新数组,但是我正在操作一个非常大的数组,并且需要很长时间。
最佳答案
不幸的是,按照 Python 内置函数的示例,numpy 并不容易获取 last 索引,尽管 first 很简单。不过,类似
def slow(arr, axis, threshold):
return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)
def fast(arr, axis, threshold):
compare = (arr > threshold)
reordered = compare.swapaxes(axis, -1)
flipped = reordered[..., ::-1]
first_above = flipped.argmax(axis=-1)
last_above = flipped.shape[-1] - first_above - 1
are_any_above = compare.any(axis=axis)
# patch the no-matching-element found values
patched = np.where(are_any_above, last_above, 0)
return patched
给我
In [14]: arr = np.random.random((100,100,30,50))
In [15]: %timeit a = slow(arr, axis=2, threshold=0.75)
1 loop, best of 3: 248 ms per loop
In [16]: %timeit b = fast(arr, axis=2, threshold=0.75)
10 loops, best of 3: 50.9 ms per loop
In [17]: (slow(arr, axis=2, threshold=0.75) == fast(arr, axis=2, threshold=0.75)).all()
Out[17]: True
(可能有一种更巧妙的方法来进行翻转,但现在已经是一天的结束了,我的大脑正在关闭。:-)
关于Python:沿特定维度查找大于阈值的最大数组索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41515201/