Numpy 的 ufunc
我们有一个 reduceat
在数组中的连续分区上运行它们的方法。所以不要写:
import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]
我会写:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
两者都将返回切片中的最大值 a[0:4]
, a[4:5]
, a[5:10]
, 是 [8, 0, 9]
.
我想要一个类似的函数来执行 argmax
,注意到我只想在每个分区中有一个单个最大索引:[3, 4, 5]
与上述 a
和 split_at
(尽管索引 5 和 9 都获得了最后一组中的最大值),正如
np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]
我将在下面发布一个可能的解决方案,但我希望看到一个无需在组上创建索引即可矢量化的解决方案。
最佳答案
此解决方案涉及在组(上例中的 [0, 0, 0, 0, 1, 2, 2, 2, 2, 2]
)上构建索引。
group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)
然后我们可以使用:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]
在 result
中得到 [3, 4, 5]
。 [::-1]
确保我们得到每个组中的 第一个 而不是最后一个 argmax。
这依赖于花式赋值中的最后一个索引决定赋值的事实,@seberg says one shouldn't rely on (使用 result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]]
可以实现更安全的替代方案,它涉及对 len(maxima) 的排序~ n_groups
个元素)。
关于python - 在 numpy 中对分区索引分组 argmax/argmin,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22124332/