Python:沿特定维度查找大于阈值的最大数组索引

标签 python arrays performance numpy

假设我有一个 4-D numpy 数组(例如:np.rand((x,y,z,t))),其维度对应于 X、Y、Z,和时间。

对于每个 X 和 Y 点,在每个时间步,我想找到 Z 中数据大于某个阈值 n 的最大索引。

所以我的最终结果应该是一个 X-by-Y-by-t 数组。 Z 列中没有大于阈值的值的实例应以 0 表示。

我可以逐个元素循环并构造一个新数组,但是我正在操作一个非常大的数组,并且需要很长时间。

最佳答案

不幸的是,按照 Python 内置函数的示例,numpy 并不容易获取 last 索引,尽管 first 很简单。不过,类似

def slow(arr, axis, threshold):
    return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)

def fast(arr, axis, threshold):
    compare = (arr > threshold)
    reordered = compare.swapaxes(axis, -1)
    flipped = reordered[..., ::-1]
    first_above = flipped.argmax(axis=-1)
    last_above = flipped.shape[-1] - first_above - 1
    are_any_above = compare.any(axis=axis)
    # patch the no-matching-element found values
    patched = np.where(are_any_above, last_above, 0)
    return patched

给我

In [14]: arr = np.random.random((100,100,30,50))

In [15]: %timeit a = slow(arr, axis=2, threshold=0.75)
1 loop, best of 3: 248 ms per loop

In [16]: %timeit b = fast(arr, axis=2, threshold=0.75)
10 loops, best of 3: 50.9 ms per loop

In [17]: (slow(arr, axis=2, threshold=0.75) == fast(arr, axis=2, threshold=0.75)).all()
Out[17]: True

(可能有一种更巧妙的方法来进行翻转,但现在已经是一天的结束了,我的大脑正在关闭。:-)

关于Python:沿特定维度查找大于阈值的最大数组索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41515201/

相关文章:

c# - Crystal 报表 ExportToDisk 消耗大量时间

python - JPype 无法正确编译

python - 如何查询名称包含python列表中任何单词的模型?

python - Dash 中的布局和下拉菜单 - Python

jquery - 这些数组的正确名称

c# - 性能因 EF LINQ 连接而降低

performance - 是什么导致了这个过多的 "Composite Layers"、 "Recalculate Style"和 "Update Layer Tree"循环?

python - 克服空数组的 ValueError

javascript - 根据值重新排序数组

arrays - 有效二维数组的delta数据结构